diff --git a/CMakeLists.txt b/CMakeLists.txt
index 732b2dea3..21c10f99a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -99,12 +99,12 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_catalog_catalog_entry.cpp
  src/duckdb/ub_src_catalog_catalog_entry_dependency.cpp
  src/duckdb/ub_src_catalog_default.cpp
- src/duckdb/ub_src_common_adbc.cpp
+ src/duckdb/src/common/adbc/adbc.cpp
  src/duckdb/ub_src_common_adbc_nanoarrow.cpp
  src/duckdb/ub_src_common.cpp
  src/duckdb/ub_src_common_arrow_appender.cpp
  src/duckdb/ub_src_common_arrow.cpp
- src/duckdb/ub_src_common_crypto.cpp
+ src/duckdb/src/common/crypto/md5.cpp
  src/duckdb/ub_src_common_enums.cpp
  src/duckdb/ub_src_common_exception.cpp
  src/duckdb/ub_src_common_multi_file.cpp
@@ -113,12 +113,12 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_common_row_operations.cpp
  src/duckdb/ub_src_common_serializer.cpp
  src/duckdb/ub_src_common_sort.cpp
- src/duckdb/ub_src_common_sorting.cpp
  src/duckdb/ub_src_common_tree_renderer.cpp
  src/duckdb/ub_src_common_types.cpp
  src/duckdb/ub_src_common_types_column.cpp
  src/duckdb/ub_src_common_types_row.cpp
- src/duckdb/ub_src_common_value_operations.cpp
+ src/duckdb/ub_src_common_types_variant.cpp
+ src/duckdb/src/common/value_operations/comparison_operations.cpp
  src/duckdb/src/common/vector_operations/boolean_operators.cpp
  src/duckdb/src/common/vector_operations/comparison_operators.cpp
  src/duckdb/src/common/vector_operations/generators.cpp
@@ -136,13 +136,13 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_execution_nested_loop_join.cpp
  src/duckdb/ub_src_execution_operator_aggregate.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_buffer_manager.cpp
- src/duckdb/ub_src_execution_operator_csv_scanner_encode.cpp
+ src/duckdb/src/execution/operator/csv_scanner/encode/csv_encoder.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_scanner.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_sniffer.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_state_machine.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_table_function.cpp
  src/duckdb/ub_src_execution_operator_csv_scanner_util.cpp
- src/duckdb/ub_src_execution_operator_filter.cpp
+ src/duckdb/src/execution/operator/filter/physical_filter.cpp
  src/duckdb/ub_src_execution_operator_helper.cpp
  src/duckdb/ub_src_execution_operator_join.cpp
  src/duckdb/ub_src_execution_operator_order.cpp
@@ -154,20 +154,21 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_execution_physical_plan.cpp
  src/duckdb/ub_src_execution_sample.cpp
  src/duckdb/ub_src_function_aggregate_distributive.cpp
- src/duckdb/ub_src_function_aggregate.cpp
+ src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp
  src/duckdb/ub_src_function.cpp
  src/duckdb/ub_src_function_cast.cpp
- src/duckdb/ub_src_function_cast_union.cpp
+ src/duckdb/src/function/cast/union/from_struct.cpp
  src/duckdb/ub_src_function_cast_variant.cpp
  src/duckdb/ub_src_function_pragma.cpp
  src/duckdb/ub_src_function_scalar_compressed_materialization.cpp
  src/duckdb/ub_src_function_scalar.cpp
- src/duckdb/ub_src_function_scalar_date.cpp
+ src/duckdb/src/function/scalar/date/strftime.cpp
  src/duckdb/ub_src_function_scalar_generic.cpp
+ src/duckdb/src/function/scalar/geometry/geometry_functions.cpp
  src/duckdb/ub_src_function_scalar_list.cpp
- src/duckdb/ub_src_function_scalar_map.cpp
+ src/duckdb/src/function/scalar/map/map_contains.cpp
  src/duckdb/ub_src_function_scalar_operator.cpp
- src/duckdb/ub_src_function_scalar_sequence.cpp
+ src/duckdb/src/function/scalar/sequence/nextval.cpp
  src/duckdb/ub_src_function_scalar_string.cpp
  src/duckdb/ub_src_function_scalar_string_regexp.cpp
  src/duckdb/ub_src_function_scalar_struct.cpp
@@ -176,7 +177,8 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_function_table_arrow.cpp
  src/duckdb/ub_src_function_table.cpp
  src/duckdb/ub_src_function_table_system.cpp
- src/duckdb/ub_src_function_table_version.cpp
+ src/duckdb/src/function/table/version/pragma_version.cpp
+ src/duckdb/src/function/variant/variant_shredding.cpp
  src/duckdb/ub_src_function_window.cpp
  src/duckdb/ub_src_logging.cpp
  src/duckdb/ub_src_main.cpp
@@ -189,14 +191,14 @@ set(DUCKDB_SRC_FILES
  src/duckdb/src/main/extension/extension_install.cpp
  src/duckdb/src/main/extension/extension_load.cpp
  src/duckdb/src/main/extension/extension_loader.cpp
- src/duckdb/ub_src_main_http.cpp
+ src/duckdb/src/main/http/http_util.cpp
  src/duckdb/ub_src_main_relation.cpp
  src/duckdb/ub_src_main_secret.cpp
  src/duckdb/ub_src_main_settings.cpp
  src/duckdb/ub_src_optimizer.cpp
  src/duckdb/ub_src_optimizer_compressed_materialization.cpp
  src/duckdb/ub_src_optimizer_join_order.cpp
- src/duckdb/ub_src_optimizer_matcher.cpp
+ src/duckdb/src/optimizer/matcher/expression_matcher.cpp
  src/duckdb/ub_src_optimizer_pullup.cpp
  src/duckdb/ub_src_optimizer_pushdown.cpp
  src/duckdb/ub_src_optimizer_rule.cpp
@@ -210,7 +212,7 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_parser_query_node.cpp
  src/duckdb/ub_src_parser_statement.cpp
  src/duckdb/ub_src_parser_tableref.cpp
- src/duckdb/ub_src_parser_transform_constraint.cpp
+ src/duckdb/src/parser/transform/constraint/transform_constraint.cpp
  src/duckdb/ub_src_parser_transform_expression.cpp
  src/duckdb/ub_src_parser_transform_helpers.cpp
  src/duckdb/ub_src_parser_transform_statement.cpp
@@ -238,6 +240,7 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_storage_serialization.cpp
  src/duckdb/ub_src_storage_statistics.cpp
  src/duckdb/ub_src_storage_table.cpp
+ src/duckdb/ub_src_storage_table_variant.cpp
  src/duckdb/ub_src_transaction.cpp
  src/duckdb/src/verification/copied_statement_verifier.cpp
  src/duckdb/src/verification/deserialized_statement_verifier.cpp
@@ -356,12 +359,12 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp
  src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp
  src/duckdb/ub_extension_core_functions_scalar_string.cpp
- src/duckdb/ub_extension_core_functions_scalar_bit.cpp
- src/duckdb/ub_extension_core_functions_scalar_operators.cpp
- src/duckdb/ub_extension_core_functions_scalar_enum.cpp
+ src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp
+ src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp
+ src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp
  src/duckdb/ub_extension_core_functions_scalar_map.cpp
  src/duckdb/ub_extension_core_functions_scalar_random.cpp
- src/duckdb/ub_extension_core_functions_scalar_math.cpp
+ src/duckdb/extension/core_functions/scalar/math/numeric.cpp
  src/duckdb/ub_extension_core_functions_scalar_union.cpp
  src/duckdb/ub_extension_core_functions_scalar_generic.cpp
  src/duckdb/ub_extension_core_functions_scalar_struct.cpp
@@ -377,10 +380,13 @@ set(DUCKDB_SRC_FILES
  src/duckdb/extension/parquet/parquet_timestamp.cpp
  src/duckdb/extension/parquet/parquet_float16.cpp
  src/duckdb/extension/parquet/parquet_statistics.cpp
+ src/duckdb/extension/parquet/parquet_shredding.cpp
+ src/duckdb/extension/parquet/parquet_geometry.cpp
  src/duckdb/extension/parquet/parquet_multi_file_info.cpp
  src/duckdb/extension/parquet/column_reader.cpp
- src/duckdb/extension/parquet/geo_parquet.cpp
+ src/duckdb/extension/parquet/parquet_field_id.cpp
  src/duckdb/extension/parquet/parquet_extension.cpp
+ src/duckdb/extension/parquet/parquet_column_schema.cpp
  src/duckdb/extension/parquet/column_writer.cpp
  src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp
  src/duckdb/extension/parquet/serialize_parquet.cpp
@@ -389,6 +395,7 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_extension_parquet_reader.cpp
  src/duckdb/ub_extension_parquet_reader_variant.cpp
  src/duckdb/ub_extension_parquet_writer.cpp
+ src/duckdb/ub_extension_parquet_writer_variant.cpp
  src/duckdb/third_party/parquet/parquet_types.cpp
  src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp
  src/duckdb/third_party/thrift/thrift/transport/TTransportException.cpp
@@ -453,7 +460,8 @@ set(DUCKDB_SRC_FILES
  src/duckdb/extension/json/json_common.cpp
  src/duckdb/extension/json/json_deserializer.cpp
  src/duckdb/extension/json/json_serializer.cpp
- src/duckdb/ub_extension_json_json_functions.cpp)
+ src/duckdb/ub_extension_json_json_functions.cpp
+ src/duckdb/generated_extension_loader_package_build.cpp)

 set(JEMALLOC_SRC_FILES
  src/duckdb/extension/jemalloc/jemalloc_extension.cpp

diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
index c2cfd61f8..6e55010e2 100644
--- a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
@@ -272,7 +272,7 @@ unique_ptr BindDecimalAvg(ClientContext &context, AggregateFunctio
 	function = GetAverageAggregate(decimal_type.InternalType());
 	function.name = "avg";
 	function.arguments[0] = decimal_type;
-	function.return_type = LogicalType::DOUBLE;
+	function.SetReturnType(LogicalType::DOUBLE);
 	return make_uniq(
 	    Hugeint::Cast(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)]));
 }

diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
index 40b426390..9e478dedd 100644
--- a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
@@ -90,7 +90,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type)
 	                              AggregateFunction::StateCombine,
 	                              AggregateFunction::StateFinalize, ApproxCountDistinctSimpleUpdateFunction);
-	fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
 	return fun;
 }

diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
index d2bdfbe54..51c2becfa 100644
--- a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
@@ -15,7 +15,7 @@ namespace duckdb {
 namespace {

 struct ArgMinMaxStateBase {
-	ArgMinMaxStateBase() : is_initialized(false), arg_null(false) {
+	ArgMinMaxStateBase() : is_initialized(false), arg_null(false), val_null(false) {
 	}

 	template
@@ -34,6 +34,7 @@ struct ArgMinMaxStateBase {

 	bool is_initialized;
 	bool arg_null;
+	bool val_null;
 };

 // Out-of-line specialisations
@@ -81,7 +82,7 @@ struct ArgMinMaxState : public ArgMinMaxStateBase {
 	}
 };

-template
+template
 struct ArgMinMaxBase {
 	template
 	static void Initialize(STATE &state) {
@@ -94,25 +95,48 @@ struct ArgMinMaxBase {
 	}

 	template
-	static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool
x_null, + static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, const bool y_null, AggregateInputData &aggregate_input_data) { - if (IGNORE_NULL) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL) { STATE::template AssignValue(state.arg, x, aggregate_input_data); STATE::template AssignValue(state.value, y, aggregate_input_data); } else { state.arg_null = x_null; + state.val_null = y_null; if (!state.arg_null) { STATE::template AssignValue(state.arg, x, aggregate_input_data); } - STATE::template AssignValue(state.value, y, aggregate_input_data); + if (!state.val_null) { + STATE::template AssignValue(state.value, y, aggregate_input_data); + } } } template static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) { + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); if (!state.is_initialized) { - if (IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) { - Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + binary.left_mask.RowIsValid(binary.lidx) && binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + binary.right_mask.RowIsValid(binary.ridx)) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); + state.is_initialized = true; + return; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL) { + Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx), + !binary.right_mask.RowIsValid(binary.ridx), binary.input); state.is_initialized = true; } } else { @@ -122,8 +146,15 @@ struct ArgMinMaxBase { template static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) { - if ((IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) && COMPARATOR::Operation(y_data, state.value)) { - Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), binary.input); + D_ASSERT(binary.input.bind_data); + const auto &bind_data = binary.input.bind_data->Cast(); + + if (binary.right_mask.RowIsValid(binary.ridx) && + (state.val_null || COMPARATOR::Operation(y_data, state.value))) { + if (bind_data.null_handling != ArgMinMaxNullHandling::IGNORE_ANY_NULL || + binary.left_mask.RowIsValid(binary.lidx)) { + Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), false, binary.input); + } } } @@ -132,8 +163,10 @@ struct ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - Assign(target, source.arg, source.value, source.arg_null, aggregate_input_data); + + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + Assign(target, source.arg, source.value, source.arg_null, false, aggregate_input_data); target.is_initialized = true; } } @@ -148,17 +181,20 @@ struct ArgMinMaxBase { } static bool IgnoreNull() { - return IGNORE_NULL; + return false; } + template static unique_ptr Bind(ClientContext &context, 
AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); } function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; @@ -186,12 +222,14 @@ struct GenericArgMinMaxState { } }; -template -struct VectorArgMinMaxBase : ArgMinMaxBase { +template +struct VectorArgMinMaxBase : ArgMinMaxBase { template static void Update(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggregate_input_data.bind_data); + const auto &bind_data = aggregate_input_data.bind_data->Cast(); + auto &arg = inputs[0]; UnifiedVectorFormat adata; arg.ToUnifiedFormat(count, adata); @@ -213,21 +251,36 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { auto states = UnifiedVectorFormat::GetData(sdata); for (idx_t i = 0; i < count; i++) { - const auto bidx = bdata.sel->get_index(i); - if (!bdata.validity.RowIsValid(bidx)) { - continue; - } - const auto bval = bys[bidx]; + const auto sidx = sdata.sel->get_index(i); + auto &state = *states[sidx]; const auto aidx = adata.sel->get_index(i); const auto arg_null = !adata.validity.RowIsValid(aidx); - if (IGNORE_NULL && arg_null) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && arg_null) { continue; } - const auto sidx = sdata.sel->get_index(i); - auto &state = *states[sidx]; - if (!state.is_initialized || COMPARATOR::template Operation(bval, state.value)) { + const auto bidx = bdata.sel->get_index(i); + + if (!bdata.validity.RowIsValid(bidx)) { + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL && !state.is_initialized) { + state.val_null = true; + if (!arg_null) { + state.is_initialized = true; + if (&state == last_state) { + assign_count--; + } + assign_sel[assign_count++] = UnsafeNumericCast(i); + last_state = &state; + } + } + continue; + } + + const auto bval = bys[bidx]; + + if (!state.is_initialized || state.val_null || COMPARATOR::template Operation(bval, state.value)) { STATE::template AssignValue(state.value, bval, aggregate_input_data); state.arg_null = arg_null; // micro-adaptivity: it is common we overwrite the same state repeatedly @@ -270,8 +323,12 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { if (!source.is_initialized) { return; } - if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) { - STATE::template AssignValue(target.value, source.value, aggregate_input_data); + if (!target.is_initialized || target.val_null || + (!source.val_null && COMPARATOR::Operation(source.value, target.value))) { + target.val_null = source.val_null; + if (!target.val_null) { + STATE::template AssignValue(target.value, source.value, aggregate_input_data); + } target.arg_null = source.arg_null; if (!target.arg_null) { STATE::template AssignValue(target.arg, source.arg, aggregate_input_data); @@ -290,38 +347,56 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { } } + template static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) { ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type); } function.arguments[0] = arguments[0]->return_type; - 
function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } }; template -AggregateFunction GetGenericArgMinMaxFunction() { +bind_aggregate_function_t GetBindFunction(const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + return OP::template Bind; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + return OP::template Bind; + default: + return OP::template Bind; + } +} + +template +AggregateFunction GetGenericArgMinMaxFunction(const ArgMinMaxNullHandling null_handling) { using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction( {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, - AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); } template -AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; + auto bind = GetBindFunction(null_handling); return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::template Update, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateVoidFinalize, nullptr, bind, AggregateFunction::StateDestroy); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; return function; @@ -330,18 +405,19 @@ AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetVectorArgMinMaxFunctionInternal(by_type, type); + return GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max aggregate"); } @@ -356,30 +432,32 @@ const vector ArgMaxByTypes() { } template -void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void 
AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetVectorArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { #ifndef DUCKDB_SMALLER_BINARY using STATE = ArgMinMaxState; auto function = AggregateFunction::BinaryAggregate( type, by_type, type); if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) { - function.destructor = AggregateFunction::StateDestroy; + function.SetStateDestructorCallback(AggregateFunction::StateDestroy); } - function.bind = OP::Bind; + function.SetBindCallback(GetBindFunction(null_handling)); #else - auto function = GetGenericArgMinMaxFunction(); + auto function = GetGenericArgMinMaxFunction(null_handling); function.arguments = {type, by_type}; function.return_type = type; #endif @@ -388,18 +466,19 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const #ifndef DUCKDB_SMALLER_BINARY template -AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type, + const ArgMinMaxNullHandling null_handling) { switch (by_type.InternalType()) { case PhysicalType::INT32: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::INT128: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::DOUBLE: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); case PhysicalType::VARCHAR: - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); default: throw InternalException("Unimplemented arg_min/arg_max by aggregate"); } @@ -407,37 +486,38 @@ AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const Logic #endif template -void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) { +void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, ArgMinMaxNullHandling null_handling) { auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { #ifndef DUCKDB_SMALLER_BINARY - fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionBy(by_type, type, null_handling)); #else - fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type)); + fun.AddFunction(GetArgMinMaxFunctionInternal(by_type, type, null_handling)); #endif } } template -AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type) { +AggregateFunction 
GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type, + ArgMinMaxNullHandling null_handling) { D_ASSERT(type.id() == LogicalTypeId::DECIMAL); #ifndef DUCKDB_SMALLER_BINARY switch (type.InternalType()) { case PhysicalType::INT16: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT32: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); case PhysicalType::INT64: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); default: - return GetArgMinMaxFunctionBy(by_type, type); + return GetArgMinMaxFunctionBy(by_type, type, null_handling); } #else - return GetArgMinMaxFunctionInternal(by_type, type); + return GetArgMinMaxFunctionInternal(by_type, type, null_handling); #endif } -template +template unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function, vector> &arguments) { auto decimal_type = arguments[0]->return_type; @@ -469,51 +549,69 @@ unique_ptr BindDecimalArgMinMax(ClientContext &context, AggregateF } auto name = std::move(function.name); - function = GetDecimalArgMinMaxFunction(by_type, decimal_type); + function = GetDecimalArgMinMaxFunction(by_type, decimal_type, NULL_HANDLING); function.name = std::move(name); - function.return_type = decimal_type; - return nullptr; + function.SetReturnType(decimal_type); + + auto function_data = make_uniq(NULL_HANDLING); + return unique_ptr(std::move(function_data)); } template -void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type) { - fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax)); +void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type, + const ArgMinMaxNullHandling null_handling) { + switch (null_handling) { + case ArgMinMaxNullHandling::IGNORE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ARG_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + case ArgMinMaxNullHandling::HANDLE_ANY_NULL: + fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, + BindDecimalArgMinMax)); + break; + } } template -void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) { - fun.AddFunction(GetGenericArgMinMaxFunction()); +void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + fun.AddFunction(GetGenericArgMinMaxFunction(null_handling)); } -template -void AddArgMinMaxFunctions(AggregateFunctionSet &fun) { - using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; +template +void AddArgMinMaxFunctions(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) { + using GENERIC_VECTOR_OP = VectorArgMinMaxBase>; #ifndef DUCKDB_SMALLER_BINARY - using OP = ArgMinMaxBase; - using VECTOR_OP = VectorArgMinMaxBase; + using OP = ArgMinMaxBase; + using VECTOR_OP = VectorArgMinMaxBase; #else using OP = GENERIC_VECTOR_OP; using VECTOR_OP = GENERIC_VECTOR_OP; #endif - 
AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER); - AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT); - AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE); - AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR); - AddArgMinMaxFunctionBy(fun, LogicalType::DATE); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP); - AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ); - AddArgMinMaxFunctionBy(fun, LogicalType::BLOB); + AddArgMinMaxFunctionBy(fun, LogicalType::INTEGER, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BIGINT, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DOUBLE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::VARCHAR, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::DATE, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::TIMESTAMP_TZ, null_handling); + AddArgMinMaxFunctionBy(fun, LogicalType::BLOB, null_handling); auto by_types = ArgMaxByTypes(); for (const auto &by_type : by_types) { - AddDecimalArgMinMaxFunctionBy(fun, by_type); + AddDecimalArgMinMaxFunctionBy(fun, by_type, null_handling); } - AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY); + AddVectorArgMinMaxFunctionBy(fun, LogicalType::ANY, null_handling); // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest - AddGenericArgMinMaxFunction(fun); + AddGenericArgMinMaxFunction(fun, null_handling); } //------------------------------------------------------------------------------ @@ -547,6 +645,8 @@ class ArgMinMaxNState { template void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { + D_ASSERT(aggr_input.bind_data); + const auto &bind_data = aggr_input.bind_data->Cast(); auto &val_vector = inputs[0]; auto &arg_vector = inputs[1]; @@ -560,8 +660,8 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); - STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last); + STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); @@ -571,9 +671,16 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp for (idx_t i = 0; i < count; i++) { const auto arg_idx = arg_format.sel->get_index(i); const auto val_idx = val_format.sel->get_index(i); - if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) { + + if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && + (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx))) { + continue; + } + if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL && + !val_format.validity.RowIsValid(val_idx)) { continue; } + const auto state_idx = state_format.sel->get_index(i); auto &state = *states[state_idx]; @@ -610,13 +717,13 @@ void SpecializeArgMinMaxNFunction(AggregateFunction &function) { using STATE = ArgMinMaxNState; using OP = MinMaxNOperation; - function.state_size = AggregateFunction::StateSize; - 
function.initialize = AggregateFunction::StateInitialize; - function.combine = AggregateFunction::StateCombine; - function.destructor = AggregateFunction::StateDestroy; + function.SetStateSizeCallback(AggregateFunction::StateSize); + function.SetStateInitCallback(AggregateFunction::StateInitialize); + function.SetStateCombineCallback(AggregateFunction::StateCombine); + function.SetStateDestructorCallback(AggregateFunction::StateDestroy); - function.finalize = MinMaxNOperation::Finalize; - function.update = ArgMinMaxNUpdate; + function.SetStateFinalizeCallback(MinMaxNOperation::Finalize); + function.SetStateUpdateCallback(ArgMinMaxNUpdate); } template @@ -671,7 +778,76 @@ void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, } } -template +template +void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) { + using STATE = ArgMinMaxNState; + using OP = MinMaxNOperation; + + function.SetStateSizeCallback(AggregateFunction::StateSize); + function.SetStateInitCallback(AggregateFunction::StateInitialize); + function.SetStateCombineCallback(AggregateFunction::StateCombine); + function.SetStateDestructorCallback(AggregateFunction::StateDestroy); + function.SetStateFinalizeCallback(MinMaxNOperation::Finalize); + function.SetStateUpdateCallback(ArgMinMaxNUpdate); +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) { + switch (arg_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, COMPARATOR>(function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(function); + break; + } +} + +template +void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) { + switch (val_type) { +#ifndef DUCKDB_SMALLER_BINARY + case PhysicalType::VARCHAR: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + case PhysicalType::INT32: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::INT64: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::FLOAT: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; + case PhysicalType::DOUBLE: + SpecializeArgMinMaxNullNFunction, NULLS_LAST, COMPARATOR>(arg_type, + function); + break; +#endif + default: + SpecializeArgMinMaxNullNFunction(arg_type, function); + break; + } +} + +template unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunction &function, vector> &arguments) { for (auto &arg : arguments) { @@ -682,19 +858,24 @@ unique_ptr ArgMinMaxNBind(ClientContext &context, AggregateFunctio const auto val_type = arguments[0]->return_type.InternalType(); const auto arg_type = arguments[1]->return_type.InternalType(); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); // Specialize the function based on the input types - SpecializeArgMinMaxNFunction(val_type, arg_type, function); + auto function_data = make_uniq(NULL_HANDLING, NULLS_LAST); + if (NULL_HANDLING != ArgMinMaxNullHandling::IGNORE_ANY_NULL) { + 
SpecializeArgMinMaxNullNFunction(val_type, arg_type, function); + } else { + SpecializeArgMinMaxNFunction(val_type, arg_type, function); + } - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return nullptr; + return unique_ptr(std::move(function_data)); } -template +template void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, ArgMinMaxNBind); + nullptr, ArgMinMaxNBind); return set.AddFunction(function); } @@ -707,27 +888,41 @@ void AddArgMinMaxNFunction(AggregateFunctionSet &set) { AggregateFunctionSet ArgMinFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMaxFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); - AddArgMinMaxNFunction(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } AggregateFunctionSet ArgMinNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); return fun; } AggregateFunctionSet ArgMaxNullFun::GetFunctions() { AggregateFunctionSet fun; - AddArgMinMaxFunctions(fun); + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL); + return fun; +} + +AggregateFunctionSet ArgMinNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); + return fun; +} + +AggregateFunctionSet ArgMaxNullsLastFun::GetFunctions() { + AggregateFunctionSet fun; + AddArgMinMaxFunctions(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL); + AddArgMinMaxNFunction(fun); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp index 168d3a539..fccfd0ac8 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp @@ -166,7 +166,6 @@ struct BitStringBitwiseOperation : public BitwiseOperation { }; struct BitStringAndOperation : public BitStringBitwiseOperation { - template static void Execute(STATE &state, INPUT_TYPE input) { Bit::BitwiseAnd(input, state.value, state.value); @@ -174,7 +173,6 @@ struct BitStringAndOperation : public BitStringBitwiseOperation { }; struct BitStringOrOperation : public BitStringBitwiseOperation { - template static void Execute(STATE &state, INPUT_TYPE input) { Bit::BitwiseOr(input, state.value, state.value); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp index fad7550d8..595760985 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -235,7 +235,6 @@ idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) { unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input) { - if (NumericStats::HasMinMax(input.child_stats[0])) { auto &bind_agg_data = 
input.bind_data->Cast(); bind_agg_data.min = NumericStats::Min(input.child_stats[0]); @@ -264,13 +263,14 @@ void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &typ auto function = AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( type, LogicalType::BIT); - function.bind = BindBitstringAgg; // create new a 'BitstringAggBindData' - function.serialize = BitstringAggBindData::Serialize; - function.deserialize = BitstringAggBindData::Deserialize; - function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData + function.SetBindCallback(BindBitstringAgg); // create new a 'BitstringAggBindData' + function.SetSerializeCallback(BitstringAggBindData::Serialize); + function.SetDeserializeCallback(BitstringAggBindData::Deserialize); + function.SetStatisticsCallback( + BitstringPropagateStats); // stores min and max from column stats in BitstringAggBindData bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring function.arguments = {type, type, type}; - function.statistics = nullptr; // min and max are provided as arguments + function.SetStatisticsCallback(nullptr); // min and max are provided as arguments bitstring_agg.AddFunction(function); } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp index c8c8422d6..89646f3b1 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp @@ -98,16 +98,16 @@ struct BoolOrFunFunction { AggregateFunction BoolOrFun::GetFunction() { auto fun = AggregateFunction::UnaryAggregate( LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); return fun; } AggregateFunction BoolAndFun::GetFunction() { auto fun = AggregateFunction::UnaryAggregate( LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp index aa551eca5..d1ca6b694 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp @@ -106,7 +106,7 @@ AggregateFunction KurtosisFun::GetFunction() { auto result = AggregateFunction::UnaryAggregate>( LogicalType::DOUBLE, LogicalType::DOUBLE); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } @@ -114,7 +114,7 @@ AggregateFunction KurtosisPopFun::GetFunction() { auto result = AggregateFunction::UnaryAggregate>( LogicalType::DOUBLE, LogicalType::DOUBLE); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } diff --git 
a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp index ddbecbf28..8414a5921 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp @@ -120,6 +120,12 @@ unique_ptr StringAggBind(ClientContext &context, AggregateFunction return make_uniq(","); } D_ASSERT(arguments.size() == 2); + // Check if any argument is of UNKNOWN type (parameter not yet bound) + for (auto &arg : arguments) { + if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } + } if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); } @@ -160,8 +166,8 @@ AggregateFunctionSet StringAggFun::GetFunctions() { AggregateFunction::StateCombine, AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, StringAggBind); - string_agg_param.serialize = StringAggSerialize; - string_agg_param.deserialize = StringAggDeserialize; + string_agg_param.SetSerializeCallback(StringAggSerialize); + string_agg_param.SetDeserializeCallback(StringAggDeserialize); string_agg.AddFunction(string_agg_param); string_agg_param.arguments.emplace_back(LogicalType::VARCHAR); string_agg.AddFunction(string_agg_param); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp index bfea19644..7746cadf3 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp @@ -84,7 +84,7 @@ void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) { - function.return_type = deserializer.Get(); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -94,20 +94,20 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( LogicalType::INTEGER, LogicalType::HUGEINT); function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.bind = SumNoOverflowBind; - function.serialize = SumNoOverflowSerialize; - function.deserialize = SumNoOverflowDeserialize; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + function.SetBindCallback(SumNoOverflowBind); + function.SetSerializeCallback(SumNoOverflowSerialize); + function.SetDeserializeCallback(SumNoOverflowDeserialize); return function; } case PhysicalType::INT64: { auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( LogicalType::BIGINT, LogicalType::HUGEINT); function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.bind = SumNoOverflowBind; - function.serialize = SumNoOverflowSerialize; - function.deserialize = SumNoOverflowDeserialize; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + function.SetBindCallback(SumNoOverflowBind); + function.SetSerializeCallback(SumNoOverflowSerialize); + function.SetDeserializeCallback(SumNoOverflowDeserialize); return function; } default: @@ -118,8 +118,8 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { AggregateFunction GetSumAggregateNoOverflowDecimal() { AggregateFunction aggr({LogicalTypeId::DECIMAL}, 
LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind); - aggr.serialize = SumNoOverflowSerialize; - aggr.deserialize = SumNoOverflowDeserialize; + aggr.SetSerializeCallback(SumNoOverflowSerialize); + aggr.SetDeserializeCallback(SumNoOverflowDeserialize); return aggr; } @@ -163,13 +163,13 @@ AggregateFunction GetSumAggregate(PhysicalType type) { case PhysicalType::BOOL: { auto function = AggregateFunction::UnaryAggregate, bool, hugeint_t, IntegerSumOperation>( LogicalType::BOOLEAN, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return function; } case PhysicalType::INT16: { auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( LogicalType::SMALLINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return function; } @@ -177,23 +177,23 @@ AggregateFunction GetSumAggregate(PhysicalType type) { auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, SumToHugeintOperation>( LogicalType::INTEGER, LogicalType::HUGEINT); - function.statistics = SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetStatisticsCallback(SumPropagateStats); + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return function; } case PhysicalType::INT64: { auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, SumToHugeintOperation>( LogicalType::BIGINT, LogicalType::HUGEINT); - function.statistics = SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetStatisticsCallback(SumPropagateStats); + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return function; } case PhysicalType::INT128: { auto function = AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( LogicalType::HUGEINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return function; } default: @@ -207,8 +207,8 @@ unique_ptr BindDecimalSum(ClientContext &context, AggregateFunctio function = GetSumAggregate(decimal_type.InternalType()); function.name = "sum"; function.arguments[0] = decimal_type; - function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetReturnType(LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type))); + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return nullptr; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp index 641a5010e..2f5e15312 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -395,10 +395,10 @@ unique_ptr ApproxTopKBind(ClientContext &context, AggregateFunctio } } if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { - function.update = ApproxTopKUpdate; - function.finalize = ApproxTopKFinalize; + 
function.SetStateUpdateCallback(ApproxTopKUpdate); + function.SetStateFinalizeCallback(ApproxTopKFinalize); } - function.return_type = LogicalType::LIST(arguments[0]->return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); return nullptr; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp index 35336383b..0896eccb8 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -270,8 +270,8 @@ unique_ptr BindApproxQuantile(ClientContext &context, AggregateFun AggregateFunction ApproxQuantileDecimalFunction(const LogicalType &type) { auto function = GetApproximateQuantileDecimalAggregateFunction(type); function.name = "approx_quantile"; - function.serialize = ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; + function.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + function.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); return function; } @@ -284,9 +284,9 @@ unique_ptr BindApproxQuantileDecimal(ClientContext &context, Aggre AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { auto fun = GetApproximateQuantileAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.SetBindCallback(BindApproxQuantile); + fun.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + fun.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::FLOAT); return fun; @@ -294,7 +294,6 @@ AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { template struct ApproxQuantileListOperation : public ApproxQuantileOperation { - template static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { if (state.pos == 0) { @@ -342,8 +341,8 @@ AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType using STATE = ApproxQuantileState; using OP = ApproxQuantileListOperation; auto fun = ApproxQuantileListAggregate(type, type); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + fun.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); return fun; } @@ -355,11 +354,11 @@ AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::INTEGER: case LogicalTypeId::DATE: - case LogicalTypeId::TIME: return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::BIGINT: case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIME: return GetTypedApproxQuantileListAggregateFunction(type); case LogicalTypeId::TIME_TZ: // Not binary comparable @@ -391,8 +390,8 @@ AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) { auto function = GetApproxQuantileListAggregateFunction(type); function.name = "approx_quantile"; - function.serialize = 
ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; + function.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + function.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); return function; } @@ -405,9 +404,9 @@ unique_ptr BindApproxQuantileDecimalList(ClientContext &context, A AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { auto fun = GetApproxQuantileListAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; + fun.SetBindCallback(BindApproxQuantile); + fun.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + fun.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); // temporarily push an argument so we can bind the actual quantile auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); fun.arguments.push_back(list_of_float); @@ -429,8 +428,8 @@ AggregateFunction GetApproxQuantileDecimal() { // stub function - the actual function is set during bind or deserialize AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproxQuantileDecimalDeserialize; + fun.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + fun.SetDeserializeCallback(ApproxQuantileDecimalDeserialize); return fun; } @@ -439,8 +438,8 @@ AggregateFunction GetApproxQuantileDecimalList() { AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimalList); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproxQuantileDecimalDeserialize; + fun.SetSerializeCallback(ApproximateQuantileBindData::Serialize); + fun.SetDeserializeCallback(ApproxQuantileDecimalDeserialize); return fun; } } // namespace diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp index 9835e44b3..633b7af92 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp @@ -57,7 +57,6 @@ struct QuantileReuseUpdater { }; void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { - // Copy overlapping indices by scanning the previous set and copying down into holes. // We copy instead of leaving gaps in case there are fewer values in the current frame. 
FrameSet prev_set(prevs); @@ -268,11 +267,11 @@ AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const Logical using OP = MedianAbsoluteDeviationOperation; auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.bind = BindMAD; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetBindCallback(BindMAD); + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; + fun.SetWindowCallback(OP::template Window); + fun.SetWindowInitCallback(OP::template WindowInit); #endif return fun; } @@ -317,7 +316,7 @@ AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const Logi AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + result.SetFallible(); return result; } @@ -325,7 +324,7 @@ unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &conte vector> &arguments) { function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); function.name = "mad"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return BindMAD(context, function, arguments); } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp index dc09dd32f..15b7dd18f 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp @@ -408,7 +408,7 @@ AggregateFunction GetFallbackModeFunction(const LogicalType &type) { AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr); - aggr.destructor = AggregateFunction::StateDestroy; + aggr.SetStateDestructorCallback(AggregateFunction::StateDestroy); return aggr; } @@ -419,7 +419,7 @@ AggregateFunction GetTypedModeFunction(const LogicalType &type) { auto func = AggregateFunction::UnaryAggregateDestructor( type, type); - func.window = OP::template Window; + func.SetWindowCallback(OP::template Window); return func; } @@ -518,7 +518,7 @@ AggregateFunction GetTypedEntropyFunction(const LogicalType &type) { auto func = AggregateFunction::UnaryAggregateDestructor( type, LogicalType::DOUBLE); - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; } @@ -529,8 +529,8 @@ AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) { AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, nullptr); - func.destructor = AggregateFunction::StateDestroy; - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + func.SetStateDestructorCallback(AggregateFunction::StateDestroy); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp index 5009e9669..0578320ac 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp +++ 
b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp @@ -383,8 +383,8 @@ struct ScalarDiscreteQuantile { auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); #ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::Window; - fun.window_init = OP::WindowInit; + fun.SetWindowCallback(OP::Window); + fun.SetWindowInitCallback(OP::WindowInit); #endif return fun; } @@ -420,10 +420,10 @@ struct ListDiscreteQuantile { using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileListAggregate(type, type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; + fun.SetWindowCallback(OP::template Window); + fun.SetWindowInitCallback(OP::template WindowInit); #endif return fun; } @@ -513,10 +513,10 @@ struct ScalarContinuousQuantile { auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; + fun.SetWindowCallback(OP::template Window); + fun.SetWindowInitCallback(OP::template WindowInit); #endif return fun; } @@ -528,10 +528,10 @@ struct ListContinuousQuantile { using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileListAggregate(input_type, target_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; + fun.SetWindowCallback(OP::template Window); + fun.SetWindowInitCallback(OP::template WindowInit); #endif return fun; } @@ -639,8 +639,8 @@ struct MedianFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = CanInterpolate(type) ? 
GetContinuousQuantile(type) : GetDiscreteQuantile(type); fun.name = "median"; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; + fun.SetSerializeCallback(QuantileBindData::Serialize); + fun.SetDeserializeCallback(Deserialize); return fun; } @@ -663,12 +663,12 @@ struct DiscreteQuantileListFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetDiscreteQuantileList(type); fun.name = "quantile_disc"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; + fun.SetBindCallback(Bind); + fun.SetSerializeCallback(QuantileBindData::Serialize); + fun.SetDeserializeCallback(Deserialize); // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::LIST(LogicalType::DOUBLE)); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } @@ -691,12 +691,12 @@ struct DiscreteQuantileFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetDiscreteQuantile(type); fun.name = "quantile_disc"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; + fun.SetBindCallback(Bind); + fun.SetSerializeCallback(QuantileBindData::Serialize); + fun.SetDeserializeCallback(Deserialize); // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } @@ -724,12 +724,12 @@ struct ContinuousQuantileFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetContinuousQuantile(type); fun.name = "quantile_cont"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; + fun.SetBindCallback(Bind); + fun.SetSerializeCallback(QuantileBindData::Serialize); + fun.SetDeserializeCallback(Deserialize); // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } @@ -753,13 +753,13 @@ struct ContinuousQuantileListFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetContinuousQuantileList(type); fun.name = "quantile_cont"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; + fun.SetBindCallback(Bind); + fun.SetSerializeCallback(QuantileBindData::Serialize); + fun.SetDeserializeCallback(Deserialize); // temporarily push an argument so we can bind the actual quantile auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } @@ -786,9 +786,9 @@ AggregateFunction EmptyQuantileFunction(LogicalType input, const LogicalType &re if (extra_arg.id() != LogicalTypeId::INVALID) { fun.arguments.push_back(extra_arg); } - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = OP::Deserialize; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetSerializeCallback(QuantileBindData::Serialize); + 
fun.SetDeserializeCallback(OP::Deserialize); + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp index 583a9f55e..5e2886340 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -364,16 +364,16 @@ unique_ptr BindReservoirQuantileDecimal(ClientContext &context, Ag function = GetReservoirQuantileAggregateFunction(arguments[0]->return_type.InternalType()); auto bind_data = BindReservoirQuantile(context, function, arguments); function.name = "reservoir_quantile"; - function.serialize = ReservoirQuantileBindData::Serialize; - function.deserialize = ReservoirQuantileBindData::Deserialize; + function.SetSerializeCallback(ReservoirQuantileBindData::Serialize); + function.SetDeserializeCallback(ReservoirQuantileBindData::Deserialize); return bind_data; } AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { auto fun = GetReservoirQuantileAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.SetBindCallback(BindReservoirQuantile); + fun.SetSerializeCallback(ReservoirQuantileBindData::Serialize); + fun.SetDeserializeCallback(ReservoirQuantileBindData::Deserialize); // temporarily push an argument so we can bind the actual quantile fun.arguments.emplace_back(LogicalType::DOUBLE); return fun; @@ -381,9 +381,9 @@ AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) { auto fun = GetReservoirQuantileListAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.SetBindCallback(BindReservoirQuantile); + fun.SetSerializeCallback(ReservoirQuantileBindData::Serialize); + fun.SetDeserializeCallback(ReservoirQuantileBindData::Deserialize); // temporarily push an argument so we can bind the actual quantile auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); fun.arguments.push_back(list_of_double); @@ -410,8 +410,8 @@ void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const vector const LogicalType &return_value) { AggregateFunction fun(arguments, return_value, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, BindReservoirQuantileDecimal); - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; + fun.SetSerializeCallback(ReservoirQuantileBindData::Serialize); + fun.SetDeserializeCallback(ReservoirQuantileBindData::Deserialize); set.AddFunction(fun); fun.arguments.emplace_back(LogicalType::INTEGER); diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp index e1af92578..790c60d17 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp @@ -61,7 +61,6 @@ struct StringMapType { template void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { - D_ASSERT(input_count == 1); auto &input = 
inputs[0]; @@ -209,14 +208,13 @@ AggregateFunction GetHistogramFunction(const LogicalType &type) { template unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, vector> &arguments) { - D_ASSERT(arguments.size() == 1); if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } function = GetHistogramFunction(arguments[0]->return_type); - return make_uniq(function.return_type); + return make_uniq(function.GetReturnType()); } } // namespace diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp index 5771e14eb..92916c71b 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -47,7 +47,6 @@ struct ListFunction { void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &state_vector, idx_t count) { - D_ASSERT(input_count == 1); auto &input = inputs[0]; RecursiveUnifiedVectorFormat input_data; @@ -75,7 +74,6 @@ void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputD auto combined_ptr = FlatVector::GetData(combined); for (idx_t i = 0; i < count; i++) { - auto &state = *states_ptr[states_data.sel->get_index(i)]; if (state.linked_list.total_capacity == 0) { // NULL, no need to append @@ -98,7 +96,6 @@ void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputD void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - UnifiedVectorFormat states_data; states_vector.ToUnifiedFormat(count, states_data); auto states = UnifiedVectorFormat::GetData(states_data); @@ -132,7 +129,6 @@ void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Ve ListVector::Reserve(result, total_len); auto &result_child = ListVector::GetEntry(result); for (idx_t i = 0; i < count; i++) { - auto &state = *states[states_data.sel->get_index(i)]; const auto rid = i + offset; if (state.linked_list.total_capacity == 0) { @@ -147,7 +143,6 @@ void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Ve } void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, idx_t count) { - // Can we use destructive combining? 
if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) { ListAbsorbFunction(states_vector, combined, aggr_input_data, count); @@ -182,9 +177,8 @@ void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInput unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, vector> &arguments) { - - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return make_uniq(function.return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); + return make_uniq(function.GetReturnType()); } } // namespace diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp index 9215fcfb8..89962af8b 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp @@ -11,7 +11,7 @@ AggregateFunction RegrCountFun::GetFunction() { auto regr_count = AggregateFunction::BinaryAggregate( LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); regr_count.name = "regr_count"; - regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + regr_count.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return regr_count; } diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp index a8ba52658..68a34efeb 100644 --- a/src/duckdb/extension/core_functions/function_list.cpp +++ b/src/duckdb/extension/core_functions/function_list.cpp @@ -73,8 +73,10 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullFun), + DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullsLastFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), @@ -156,7 +158,7 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_SCALAR_FUNCTION_SET(DayOfWeekFun), DUCKDB_SCALAR_FUNCTION_SET(DayOfYearFun), DUCKDB_SCALAR_FUNCTION_SET(DecadeFun), - DUCKDB_SCALAR_FUNCTION(DecodeFun), + DUCKDB_SCALAR_FUNCTION_SET(DecodeFun), DUCKDB_SCALAR_FUNCTION(DegreesFun), DUCKDB_SCALAR_FUNCTION_ALIAS(Editdist3Fun), DUCKDB_SCALAR_FUNCTION_ALIAS(ElementAtFun), @@ -335,6 +337,7 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_SCALAR_FUNCTION(SinFun), DUCKDB_SCALAR_FUNCTION(SinhFun), DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), + DUCKDB_SCALAR_FUNCTION(SleepMsFun), DUCKDB_SCALAR_FUNCTION(SqrtFun), DUCKDB_SCALAR_FUNCTION_ALIAS(StartsWithFun), DUCKDB_SCALAR_FUNCTION(StatsFun), @@ -344,7 +347,9 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun), DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun), DUCKDB_SCALAR_FUNCTION(StructInsertFun), + DUCKDB_SCALAR_FUNCTION(StructKeysFun), DUCKDB_SCALAR_FUNCTION(StructUpdateFun), + DUCKDB_SCALAR_FUNCTION(StructValuesFun), DUCKDB_AGGREGATE_FUNCTION_SET(SumFun), DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp index 
b2626ee27..e626e117c 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp @@ -9,7 +9,8 @@ #pragma once #include "duckdb/function/aggregate_function.hpp" -#include +#include +#include namespace duckdb { diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp index 39bc9459c..4add0a00d 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp @@ -57,6 +57,16 @@ struct ArgMinNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMinNullsLastFun { + static constexpr const char *Name = "arg_min_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N minimum vals, including nulls. Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_min_nulls_last(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct ArgMaxFun { static constexpr const char *Name = "arg_max"; static constexpr const char *Parameters = "arg,val"; @@ -89,6 +99,16 @@ struct ArgMaxNullFun { static AggregateFunctionSet GetFunctions(); }; +struct ArgMaxNullsLastFun { + static constexpr const char *Name = "arg_max_nulls_last"; + static constexpr const char *Parameters = "arg,val,N"; + static constexpr const char *Description = "Finds the rows with N maximum vals, including nulls.
Calculates the arg expression at that row."; + static constexpr const char *Example = "arg_max_nulls_last(A, B, N)"; + static constexpr const char *Categories = ""; + + static AggregateFunctionSet GetFunctions(); +}; + struct BitAndFun { static constexpr const char *Name = "bit_and"; static constexpr const char *Parameters = "arg"; diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp index 2c796b2e1..c36d0d4fc 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp @@ -300,7 +300,6 @@ struct QuantileIncluded { }; struct QuantileSortTree { - unique_ptr index_tree; QuantileSortTree(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition) { diff --git a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp index dd6e29153..ddcdb92b6 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp @@ -13,7 +13,6 @@ struct InnerProductOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE result = 0; auto lhs_ptr = lhs_data; @@ -43,7 +42,6 @@ struct CosineSimilarityOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE distance = 0; TYPE norm_l = 0; TYPE norm_r = 0; @@ -78,7 +76,6 @@ struct DistanceSquaredOp { template static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - TYPE distance = 0; auto l_ptr = lhs_data; diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp index 0c036c0ba..dc019c1c7 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp @@ -17,12 +17,12 @@ namespace duckdb { struct DecodeFun { static constexpr const char *Name = "decode"; - static constexpr const char *Parameters = "blob"; - static constexpr const char *Description = "Converts `blob` to `VARCHAR`. Fails if `blob` is not valid UTF-8."; - static constexpr const char *Example = "decode('\\xC3\\xBC'::BLOB)"; + static constexpr const char *Parameters = "blob,varchar"; + static constexpr const char *Description = "Converts `blob` to `VARCHAR`. Invalid UTF-8 is handled based on the error behavior argument.
Can be 'strict' (default, fail), 'replace' to replace invalid characters with '?', or 'ignore' to skip invalid characters."; + static constexpr const char *Example = "decode('\\xC3\\xBC'::BLOB)\002decode('\\xA0'::BLOB, 'replace')\002decode('\\xA0'::BLOB, 'ignore')"; static constexpr const char *Categories = "blob"; - static ScalarFunction GetFunction(); + static ScalarFunctionSet GetFunctions(); }; struct EncodeFun { diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp index 65368b2a9..c1d0f9281 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp @@ -25,4 +25,14 @@ struct VectorTypeFun { static ScalarFunction GetFunction(); }; +struct SleepMsFun { + static constexpr const char *Name = "sleep_ms"; + static constexpr const char *Parameters = "milliseconds"; + static constexpr const char *Description = "Sleeps for the specified number of milliseconds and returns NULL"; + static constexpr const char *Example = "sleep_ms(100)"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp index 86c3188fe..349ba4d0a 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp @@ -35,4 +35,24 @@ struct StructUpdateFun { static ScalarFunction GetFunction(); }; +struct StructKeysFun { + static constexpr const char *Name = "struct_keys"; + static constexpr const char *Parameters = "struct"; + static constexpr const char *Description = "Returns the field names of a STRUCT as a list"; + static constexpr const char *Example = "struct_keys({'a': 1, 'b': 2})"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct StructValuesFun { + static constexpr const char *Name = "struct_values"; + static constexpr const char *Parameters = "struct"; + static constexpr const char *Description = "Returns the field values of a STRUCT as an UnnamedStruct"; + static constexpr const char *Example = "struct_values({'a': 1, 'b': 'world'})"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp index f1aa80af7..89356921c 100644 --- a/src/duckdb/extension/core_functions/lambda_functions.cpp +++ b/src/duckdb/extension/core_functions/lambda_functions.cpp @@ -18,7 +18,6 @@ struct LambdaExecuteInfo { LambdaExecuteInfo(ClientContext &context, const Expression &lambda_expr, const DataChunk &args, const bool has_index, const Vector &child_vector) : has_index(has_index) { - expr_executor = make_uniq(context, lambda_expr); // get the input types for the input chunk @@ -103,7 +102,6 @@ struct ListFilterFunctor { //! 
Uses the lambda vector to filter the incoming list and to append the filtered list to the result vector static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *result_entries, ListFilterInfo &info, LambdaExecuteInfo &execute_info) { - idx_t count = 0; SelectionVector sel(elem_cnt); UnifiedVectorFormat lambda_data; @@ -184,7 +182,6 @@ LambdaFunctions::GetMutableColumnInfo(vector &data) static void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::ColumnInfo &column_info, const vector &column_infos, const Vector &index_vector, LambdaExecuteInfo &info) { - info.input_chunk.SetCardinality(elem_cnt); info.lambda_chunk.SetCardinality(elem_cnt); @@ -203,7 +200,6 @@ static void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::Colum // (slice and) reference the other columns vector slices; for (idx_t i = 0; i < column_infos.size(); i++) { - if (column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { // only reference constant vectorsl info.input_chunk.data[i + slice_offset].Reference(column_infos[i].vector); @@ -273,7 +269,6 @@ LogicalType LambdaFunctions::BindBinaryChildren(const vector &funct template static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { - bool result_is_null = false; LambdaFunctions::LambdaInfo info(args, state, result, result_is_null); if (result_is_null) { @@ -302,7 +297,6 @@ static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &resul idx_t elem_cnt = 0; idx_t offset = 0; for (idx_t row_idx = 0; row_idx < info.row_count; row_idx++) { - auto list_idx = info.list_column_format.sel->get_index(row_idx); const auto &list_entry = info.list_entries[list_idx]; @@ -322,10 +316,8 @@ static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &resul // iterate the elements of the current list and create the corresponding selection vectors for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // reached STANDARD_VECTOR_SIZE elements if (elem_cnt == STANDARD_VECTOR_SIZE) { - execute_info.lambda_chunk.Reset(); ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); auto &lambda_vector = execute_info.lambda_chunk.data[0]; @@ -368,8 +360,8 @@ unique_ptr LambdaFunctions::ListLambdaPrepareBind(vectorreturn_type.id() == LogicalTypeId::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type, nullptr); + bound_function.SetReturnType(LogicalType::SQLNULL); + return make_uniq(bound_function.GetReturnType(), nullptr); } // prepared statements if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { @@ -393,7 +385,7 @@ unique_ptr LambdaFunctions::ListLambdaBind(ClientContext &context, auto &bound_lambda_expr = arguments[1]->Cast(); auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); - return make_uniq(bound_function.return_type, std::move(lambda_expr), has_index); + return make_uniq(bound_function.GetReturnType(), std::move(lambda_expr), has_index); } void LambdaFunctions::ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { diff --git a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp index ecc1ce97f..a5d7067f1 100644 --- a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp @@ -6,14 +6,13 @@ namespace duckdb { static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - const auto &lhs_type = arguments[0]->return_type; const auto &rhs_type = arguments[1]->return_type; if (lhs_type.IsUnknown() && rhs_type.IsUnknown()) { bound_function.arguments[0] = rhs_type; bound_function.arguments[1] = lhs_type; - bound_function.return_type = LogicalType::UNKNOWN; + bound_function.SetReturnType(LogicalType::UNKNOWN); return nullptr; } @@ -212,11 +211,11 @@ static void AddArrayFoldFunction(ScalarFunctionSet &set, const LogicalType &type const auto array = LogicalType::ARRAY(type, optional_idx()); if (type.id() == LogicalTypeId::FLOAT) { ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); set.AddFunction(function); } else if (type.id() == LogicalTypeId::DOUBLE) { ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); set.AddFunction(function); } else { throw NotImplementedException("Array function not implemented for type %s", type.ToString()); @@ -273,7 +272,7 @@ ScalarFunctionSet ArrayCrossProductFun::GetFunctions() { set.AddFunction( ScalarFunction({double_array, double_array}, double_array, ArrayFixedCombine)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } diff --git a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp index ec8500a87..78e025c2b 100644 --- a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp @@ -62,8 +62,8 @@ unique_ptr ArrayValueBind(ClientContext &context, ScalarFunction & // this is more for completeness reasons bound_function.varargs = child_type; - bound_function.return_type = LogicalType::ARRAY(child_type, arguments.size()); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::ARRAY(child_type, arguments.size())); + return make_uniq(bound_function.GetReturnType()); } unique_ptr ArrayValueStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -74,6 +74,7 @@ unique_ptr ArrayValueStats(ClientContext &context, FunctionStati for (idx_t i = 0; i < child_stats.size(); i++) { list_child_stats.Merge(child_stats[i]); } + list_stats.SetHasNoNullFast(); return list_stats.ToUnique(); } @@ -84,7 +85,7 @@ ScalarFunction ArrayValueFun::GetFunction() { ScalarFunction fun("array_value", {}, LogicalTypeId::ARRAY, ArrayValueFunction, ArrayValueBind, nullptr, ArrayValueStats); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp index fdd499e22..b3e26a7f1 100644 --- a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp +++ b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp @@ -47,7 +47,7 @@ ScalarFunctionSet BitStringFun::GetFunctions() { bitstring.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); for (auto &func 
: bitstring.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return bitstring; } @@ -72,7 +72,7 @@ struct GetBitOperator { ScalarFunction GetBitFun::GetFunction() { ScalarFunction func({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } @@ -100,7 +100,7 @@ static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &res ScalarFunction SetBitFun::GetFunction() { ScalarFunction function({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, SetBitOperation); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp index d2c372114..77ae51731 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp @@ -43,7 +43,7 @@ ScalarFunction ToBase64Fun::GetFunction() { ScalarFunction FromBase64Fun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp index b9bfa986a..48dda08c2 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp @@ -1,5 +1,6 @@ #include "core_functions/scalar/blob_functions.hpp" #include "utf8proc_wrapper.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/common/exception/conversion_exception.hpp" namespace duckdb { @@ -12,22 +13,87 @@ void EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { result.Reinterpret(args.data[0]); } -struct BlobDecodeOperator { +enum class DecodeErrorBehavior : uint8_t { + STRICT = 1, // raise an error + REPLACE = 2, // replace invalid characters with '?' + IGNORE = 3 // ignore invalid characters (remove from string) +}; + +DecodeErrorBehavior GetDecodeErrorBehavior(const string_t &specifier_p) { + auto size = specifier_p.GetSize(); + auto data = specifier_p.GetData(); + if (StringUtil::CIEquals(data, size, "strict", 6)) { + return DecodeErrorBehavior::STRICT; + } else if (StringUtil::CIEquals(data, size, "replace", 7)) { + return DecodeErrorBehavior::REPLACE; + } else if (StringUtil::CIEquals(data, size, "ignore", 6)) { + return DecodeErrorBehavior::IGNORE; + } else { + throw ConversionException("decode error behavior specifier \"%s\" not recognized", specifier_p.GetString()); + } +} + +struct UnaryBlobDecodeOperator { template static RESULT_TYPE Operation(INPUT_TYPE input) { auto input_data = input.GetData(); auto input_length = input.GetSize(); if (Utf8Proc::Analyze(input_data, input_length) == UnicodeType::INVALID) { throw ConversionException( - "Failure in decode: could not convert blob to UTF8 string, the blob contained invalid UTF8 characters"); + "Failure in decode: could not convert blob to UTF8 string, the blob " + "contained invalid UTF8 characters. \n" + "Use try(decode(BLOB)) to return NULL and continue instead of returning an error. " + "Specify decode(BLOB, 'replace') to replace invalid characters with '?'. 
" + "Specify decode(BLOB, 'ignore') to remove invalid characters when encountered."); } return input; } }; -void DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { +void UnaryDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // decode is also a nop cast, but requires verification if the provided string is actually + UnaryExecutor::Execute(args.data[0], result, args.size()); + StringVector::AddHeapReference(result, args.data[0]); +} + +void BinaryDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::Execute(args.data[0], result, args.size()); + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, string_t error_option) { + auto input_data = input.GetDataWriteable(); + auto input_length = input.GetSize(); + + if (Utf8Proc::Analyze(input_data, input_length) != UnicodeType::INVALID) { + return input; + } + auto const error_behavior = GetDecodeErrorBehavior(error_option); + + switch (error_behavior) { + case DecodeErrorBehavior::REPLACE: + Utf8Proc::MakeValid(input_data, input_length); + return input; + + case DecodeErrorBehavior::IGNORE: { + auto new_str = Utf8Proc::RemoveInvalid(input_data, input_length); + auto target = StringVector::EmptyString(result, new_str.size()); + auto output = target.GetDataWriteable(); + memcpy(output, new_str.data(), new_str.size()); + target.Finalize(); + return target; + } + + case DecodeErrorBehavior::STRICT: + throw ConversionException( + "Failure in decode: could not convert blob to UTF8 string, the blob " + "contained invalid UTF8 characters. \n" + "Use try(decode(BLOB)) to return NULL and continue instead of returning an error. " + "Specify decode(BLOB, 'replace') to replace invalid characters with '?'. 
" + "Specify decode(BLOB, 'ignore') to remove invalid characters when encountered."); + + default: + throw InternalException("Unimplemented decode error behavior"); + } + }); StringVector::AddHeapReference(result, args.data[0]); } @@ -37,10 +103,20 @@ ScalarFunction EncodeFun::GetFunction() { return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, EncodeFunction); } -ScalarFunction DecodeFun::GetFunction() { - ScalarFunction function({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); - BaseScalarFunction::SetReturnsError(function); - return function; +ScalarFunctionSet DecodeFun::GetFunctions() { + ScalarFunctionSet decode("decode"); + + ScalarFunction unary_function({LogicalType::BLOB}, LogicalType::VARCHAR, UnaryDecodeFunction); + ScalarFunction binary_function({LogicalType::BLOB, LogicalType::VARCHAR}, LogicalType::VARCHAR, + BinaryDecodeFunction); + + unary_function.SetFallible(); + binary_function.SetFallible(); + + decode.AddFunction(unary_function); + decode.AddFunction(binary_function); + + return decode; } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/current.cpp b/src/duckdb/extension/core_functions/scalar/date/current.cpp index aa041f627..bf928618d 100644 --- a/src/duckdb/extension/core_functions/scalar/date/current.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/current.cpp @@ -23,7 +23,7 @@ static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, V ScalarFunction GetCurrentTimestampFun::GetFunction() { ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); - current_timestamp.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_timestamp.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_timestamp; } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp index 7ced59dcb..3633f1f53 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp @@ -1772,18 +1772,20 @@ unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bo arguments.erase(arguments.begin()); bound_function.arguments.erase(bound_function.arguments.begin()); bound_function.name = "julian"; - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); switch (arguments[0]->return_type.id()) { case LogicalType::TIMESTAMP: case LogicalType::TIMESTAMP_S: case LogicalType::TIMESTAMP_MS: case LogicalType::TIMESTAMP_NS: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + bound_function.SetFunctionCallback( + DatePart::UnaryFunction); + bound_function.SetStatisticsCallback( + DatePart::JulianDayOperator::template PropagateStatistics); break; case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::JulianDayOperator::template PropagateStatistics); break; default: throw BinderException("%s can only take DATE or TIMESTAMP arguments", bound_function.name); @@ -1793,34 +1795,34 @@ unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bo arguments.erase(arguments.begin()); bound_function.arguments.erase(bound_function.arguments.begin()); 
bound_function.name = "epoch"; - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); switch (arguments[0]->return_type.id()) { case LogicalType::TIMESTAMP: case LogicalType::TIMESTAMP_S: case LogicalType::TIMESTAMP_MS: case LogicalType::TIMESTAMP_NS: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; case LogicalType::INTERVAL: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; case LogicalType::TIME: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; case LogicalType::TIME_NS: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; case LogicalType::TIME_TZ: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; + bound_function.SetFunctionCallback(DatePart::UnaryFunction); + bound_function.SetStatisticsCallback(DatePart::EpochOperator::template PropagateStatistics); break; default: throw BinderException("%s can only take temporal arguments", bound_function.name); @@ -1844,7 +1846,7 @@ ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar nullptr, ts_stats, DATE_CACHE)); operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; } @@ -1974,8 +1976,8 @@ struct StructDatePart { } Function::EraseArgument(bound_function, arguments, 0); - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type, part_codes); + bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); + return make_uniq(bound_function.GetReturnType(), part_codes); } template @@ -2122,8 +2124,8 @@ struct StructDatePart { auto part_type = LogicalType::LIST(LogicalType::VARCHAR); auto result_type = LogicalType::STRUCT({}); ScalarFunction result({part_type, temporal_type}, result_type, Function, Bind); - result.serialize = SerializeFunction; - result.deserialize = DeserializeFunction; + result.SetSerializeCallback(SerializeFunction); + result.SetDeserializeCallback(DeserializeFunction); return result; } }; @@ -2168,7 
+2170,7 @@ ScalarFunctionSet QuarterFun::GetFunctions() { ScalarFunctionSet DayOfWeekFun::GetFunctions() { auto set = GetDatePartFunction(); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } @@ -2203,7 +2205,7 @@ ScalarFunctionSet TimezoneFun::GetFunctions() { operator_set.AddFunction(function); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; @@ -2408,7 +2410,7 @@ ScalarFunctionSet DatePartFun::GetFunctions() { date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME_TZ)); for (auto &func : date_part.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return date_part; diff --git a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp index dab1c8231..6e3d450d7 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp @@ -37,7 +37,6 @@ struct DateSub { struct MonthOperator { template static inline TR Operation(TA start_ts, TB end_ts) { - if (start_ts > end_ts) { return -MonthOperator::Operation(end_ts, start_ts); } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp index 819efbac4..912af31f2 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp @@ -693,25 +693,25 @@ unique_ptr DateTruncBind(ClientContext &context, ScalarFunction &b case DatePartSpecifier::JULIAN_DAY: switch (bound_function.arguments[1].id()) { case LogicalType::TIMESTAMP: - bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); + bound_function.SetFunctionCallback(DateTruncFunction); + bound_function.SetStatisticsCallback(DateTruncStats(part_code)); break; case LogicalType::DATE: - bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); + bound_function.SetFunctionCallback(DateTruncFunction); + bound_function.SetStatisticsCallback(DateTruncStats(part_code)); break; default: throw NotImplementedException("Temporal argument type for DATETRUNC"); } - bound_function.return_type = LogicalType::DATE; + bound_function.SetReturnType(LogicalType::DATE); break; default: switch (bound_function.arguments[1].id()) { case LogicalType::TIMESTAMP: - bound_function.statistics = DateTruncStats(part_code); + bound_function.SetStatisticsCallback(DateTruncStats(part_code)); break; case LogicalType::DATE: - bound_function.statistics = DateTruncStats(part_code); + bound_function.SetStatisticsCallback(DateTruncStats(part_code)); break; default: throw NotImplementedException("Temporal argument type for DATETRUNC"); @@ -733,7 +733,7 @@ ScalarFunctionSet DateTruncFun::GetFunctions() { date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL, DateTruncFunction)); for (auto &func : date_trunc.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return date_trunc; } diff --git a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp index 189d2a229..d7f1eaf99 100644 --- a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp @@ -65,7 +65,6 @@ void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &res struct MakeTimeOperator { template static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { - auto hh_32 = Cast::Operation(hh); auto mm_32 = Cast::Operation(mm); // Have to check this separately because safe casting of DOUBLE => INT32 can round. @@ -149,7 +148,7 @@ ScalarFunctionSet MakeDateFun::GetFunctions() { make_date.AddFunction( ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); for (auto &func : make_date.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return make_date; } @@ -157,7 +156,7 @@ ScalarFunctionSet MakeDateFun::GetFunctions() { ScalarFunction MakeTimeFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, ExecuteMakeTime); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -170,7 +169,7 @@ ScalarFunctionSet MakeTimestampFun::GetFunctions() { ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return operator_set; } diff --git a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp index 6427a55f5..e767282d3 100644 --- a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp @@ -16,7 +16,6 @@ namespace duckdb { namespace { struct TimeBucket { - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... 
for TimescaleDB compatibility // There are 10959 days between 1970-01-01 and 2000-01-03 constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; @@ -369,7 +368,7 @@ ScalarFunctionSet TimeBucketFun::GetFunctions() { time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, TimeBucketOriginFunction)); for (auto &func : time_bucket.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return time_bucket; } diff --git a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp index d8c0f58e0..8ad21e543 100644 --- a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp @@ -183,7 +183,7 @@ ScalarFunctionSet GetIntegerIntervalFunctions() { function_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction)); for (auto &func : function_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return function_set; } @@ -225,35 +225,35 @@ ScalarFunctionSet ToDaysFun::GetFunctions() { ScalarFunction ToHoursFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToMinutesFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToSecondsFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToMillisecondsFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } ScalarFunction ToMicrosecondsFun::GetFunction() { ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp b/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp new file mode 100644 index 000000000..b0cdc4646 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp @@ -0,0 +1,41 @@ +#include "core_functions/scalar/debug_functions.hpp" + +#include "duckdb/common/vector_operations/generic_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +#include "duckdb/common/thread.hpp" + +namespace duckdb { + +struct NullResultType { + using STRUCT_STATE = PrimitiveTypeState; + + static void AssignResult(Vector &result, idx_t i, NullResultType) { + FlatVector::SetNull(result, i, true); + } +}; + +static void SleepFunction(DataChunk &input, ExpressionState &state, Vector &result) { + input.Flatten(); + GenericExecutor::ExecuteUnary, NullResultType>(input.data[0], result, input.size(), + [](PrimitiveType input) { + // Sleep for the specified number of + // milliseconds (clamp negative values to + // 0) + int64_t sleep_ms = input.val; + if (sleep_ms < 0) { + 
sleep_ms = 0; + } + ThreadUtil::SleepMs(sleep_ms); + return NullResultType(); + }); +} + +ScalarFunction SleepMsFun::GetFunction() { + auto sleep_fun = ScalarFunction({LogicalType::BIGINT}, LogicalType::SQLNULL, SleepFunction, nullptr); + sleep_fun.stability = FunctionStability::VOLATILE; + sleep_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return sleep_fun; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp index 627d7ac28..73545544f 100644 --- a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp @@ -17,7 +17,7 @@ ScalarFunction VectorTypeFun::GetFunction() { {LogicalType::ANY}, // argument list LogicalType::VARCHAR, // return type VectorTypeFunction); - vector_type_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + vector_type_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return vector_type_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp index be3c5c03b..8de43097e 100644 --- a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp @@ -90,16 +90,16 @@ static unique_ptr BindEnumCodeFunction(ClientContext &context, Sca auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type); switch (phy_type) { case PhysicalType::UINT8: - bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UTINYINT)); break; case PhysicalType::UINT16: - bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::USMALLINT)); break; case PhysicalType::UINT32: - bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UINTEGER)); break; case PhysicalType::UINT64: - bound_function.return_type = LogicalType(LogicalTypeId::UBIGINT); + bound_function.SetReturnType(LogicalType(LogicalTypeId::UBIGINT)); break; default: throw InternalException("Unsupported Enum Internal Type"); @@ -131,33 +131,33 @@ static unique_ptr BindEnumRangeBoundaryFunction(ClientContext &con ScalarFunction EnumFirstFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumLastFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumCodeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumRangeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + 
fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } ScalarFunction EnumRangeBoundaryFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp index 4edadcaaf..222510cb8 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp @@ -11,7 +11,7 @@ static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &resul ScalarFunction AliasFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp index ffaceaf3d..753927b27 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp @@ -422,7 +422,7 @@ unique_ptr BindEquiWidthFunction(ClientContext &, ScalarFunction & child_type = arguments[1]->return_type; break; } - bound_function.return_type = LogicalType::LIST(child_type); + bound_function.SetReturnType(LogicalType::LIST(child_type)); return nullptr; } @@ -478,7 +478,7 @@ void EquiWidthBinSerialize(Serializer &, const optional_ptr, const } unique_ptr EquiWidthBinDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.return_type = deserializer.Get(); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -502,9 +502,9 @@ ScalarFunctionSet EquiWidthBinsFun::GetFunctions() { LogicalType::BIGINT, LogicalType::BOOLEAN}, LogicalType::LIST(LogicalType::ANY), UnsupportedEquiWidth, BindEquiWidthFunction)); for (auto &function : functions.functions) { - function.serialize = EquiWidthBinSerialize; - function.deserialize = EquiWidthBinDeserialize; - BaseScalarFunction::SetReturnsError(function); + function.SetSerializeCallback(EquiWidthBinSerialize); + function.SetDeserializeCallback(EquiWidthBinDeserialize); + function.SetFallible(); } return functions; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp index 1f28c8da8..ca212a29d 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp @@ -36,8 +36,8 @@ unique_ptr BindCanCastImplicitlyExpression(FunctionBindExpressionInp ScalarFunction CanCastImplicitlyFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, CanCastImplicitlyFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindCanCastImplicitlyExpression; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetBindExpressionCallback(BindCanCastImplicitlyExpression); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp 
b/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp index 4b87705d7..51ff98ace 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/cast_to_type.cpp @@ -24,8 +24,8 @@ unique_ptr BindCastToTypeFunction(FunctionBindExpressionInput &input } // namespace ScalarFunction CastToTypeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, CastToTypeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindCastToTypeFunction; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetBindExpressionCallback(BindCastToTypeFunction); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp index 4464f0544..600cad564 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp @@ -5,6 +5,8 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/exception/parser_exception.hpp" + namespace duckdb { namespace { @@ -33,7 +35,6 @@ void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &res unique_ptr CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto &key_child = arguments[0]; if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); @@ -53,13 +54,10 @@ unique_ptr CurrentSettingBind(ClientContext &context, ScalarFuncti if (!context.TryGetCurrentSetting(key, val)) { auto extension_name = Catalog::AutoloadExtensionByConfigName(context, key); // If autoloader didn't throw, the config is now available - if (!context.TryGetCurrentSetting(key, val)) { - throw InternalException("Extension %s did not provide the '%s' config setting", - extension_name.ToStdString(), key); - } + context.TryGetCurrentSetting(key, val); } - bound_function.return_type = val.type(); + bound_function.SetReturnType(val.type()); return make_uniq(val); } @@ -67,7 +65,7 @@ unique_ptr CurrentSettingBind(ClientContext &context, ScalarFuncti ScalarFunction CurrentSettingFun::GetFunction() { auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp index 184919447..a829d67fe 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp @@ -12,7 +12,7 @@ static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result ScalarFunction HashFun::GetFunction() { auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); hash_fun.varargs = LogicalType::ANY; - hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + hash_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return hash_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/least.cpp b/src/duckdb/extension/core_functions/scalar/generic/least.cpp index 
519350c1b..a38f0f29e 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/least.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/least.cpp @@ -203,36 +203,36 @@ unique_ptr BindLeastGreatest(ClientContext &context, ScalarFunctio #ifndef DUCKDB_SMALLER_BINARY case PhysicalType::BOOL: case PhysicalType::INT8: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::INT16: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::INT32: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::INT64: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::INT128: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::DOUBLE: - bound_function.function = LeastGreatestFunction; + bound_function.SetFunctionCallback(LeastGreatestFunction); break; case PhysicalType::VARCHAR: - bound_function.function = LeastGreatestFunction>; + bound_function.SetFunctionCallback(LeastGreatestFunction>); break; #endif default: // fallback with sort keys - bound_function.function = LeastGreatestFunction; - bound_function.init_local_state = LeastGreatestSortKeyInit; + bound_function.SetFunctionCallback(LeastGreatestFunction); + bound_function.SetInitStateCallback(LeastGreatestSortKeyInit); break; } bound_function.arguments[0] = child_type; bound_function.varargs = child_type; - bound_function.return_type = child_type; + bound_function.SetReturnType(child_type); return nullptr; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp b/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp index b6c823a33..64bfc9ffa 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/replace_type.cpp @@ -8,7 +8,7 @@ static void ReplaceTypeFunction(DataChunk &, ExpressionState &, Vector &) { throw InternalException("ReplaceTypeFunction function cannot be executed directly"); } -unique_ptr BindReplaceTypeFunction(FunctionBindExpressionInput &input) { +static unique_ptr BindReplaceTypeFunction(FunctionBindExpressionInput &input) { const auto &from = input.children[1]->return_type; const auto &to = input.children[2]->return_type; if (from.id() == LogicalTypeId::UNKNOWN || to.id() == LogicalTypeId::UNKNOWN) { @@ -26,8 +26,8 @@ unique_ptr BindReplaceTypeFunction(FunctionBindExpressionInput &inpu ScalarFunction ReplaceTypeFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ReplaceTypeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindReplaceTypeFunction; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetBindExpressionCallback(BindReplaceTypeFunction); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp index d6b5f5e13..3bd18ae01 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp @@ -49,8 +49,8 @@ unique_ptr StatsPropagateStats(ClientContext &context, FunctionS ScalarFunction 
StatsFun::GetFunction() { ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, StatsPropagateStats); - stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - stats.stability = FunctionStability::VOLATILE; + stats.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + stats.SetStability(FunctionStability::VOLATILE); return stats; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp index 5a0b25a6d..ea35972bd 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp @@ -108,19 +108,19 @@ void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { ScalarFunction CurrentQueryFun::GetFunction() { ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); - current_query.stability = FunctionStability::VOLATILE; + current_query.SetStability(FunctionStability::VOLATILE); return current_query; } ScalarFunction CurrentSchemaFun::GetFunction() { ScalarFunction current_schema({}, LogicalType::VARCHAR, CurrentSchemaFunction); - current_schema.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_schema.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_schema; } ScalarFunction CurrentDatabaseFun::GetFunction() { ScalarFunction current_database({}, LogicalType::VARCHAR, CurrentDatabaseFunction); - current_database.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_database.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_database; } @@ -128,20 +128,20 @@ ScalarFunction CurrentSchemasFun::GetFunction() { auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); ScalarFunction current_schemas({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction, CurrentSchemasBind); - current_schemas.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_schemas.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_schemas; } ScalarFunction InSearchPathFun::GetFunction() { ScalarFunction in_search_path({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, InSearchPathFunction); - in_search_path.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + in_search_path.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return in_search_path; } ScalarFunction CurrentTransactionIdFun::GetFunction() { ScalarFunction txid_current({}, LogicalType::UBIGINT, TransactionIdCurrent); - txid_current.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + txid_current.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return txid_current; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp index a5d26ad8c..008df7678 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp @@ -25,8 +25,8 @@ unique_ptr BindTypeOfFunctionExpression(FunctionBindExpressionInput ScalarFunction TypeOfFun::GetFunction() { auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindTypeOfFunctionExpression; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + 
fun.SetBindExpressionCallback(BindTypeOfFunctionExpression); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp index 98cbef28a..05124ea9b 100644 --- a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp @@ -161,7 +161,6 @@ template void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, optional_ptr step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { - // check all this nullness early auto str_valid = !ConstantVector::IsNull(str_vector); auto begin_valid = !ConstantVector::IsNull(begin_vector); @@ -404,11 +403,11 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & auto child_type = ArrayType::GetChildType(arguments[0]->return_type); auto target_type = LogicalType::LIST(child_type); arguments[0] = BoundCastExpression::AddCastToType(context, std::move(arguments[0]), target_type); - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); } break; case LogicalTypeId::LIST: // The result is the same type - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); break; case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: @@ -421,9 +420,9 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & if (arguments[0]->return_type.IsJSONType()) { // This is needed to avoid producing invalid JSON bound_function.arguments[0] = LogicalType::VARCHAR; - bound_function.return_type = LogicalType::VARCHAR; + bound_function.SetReturnType(LogicalType::VARCHAR); } else { - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); } for (idx_t i = 1; i < 3; i++) { if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { @@ -434,7 +433,7 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & case LogicalTypeId::SQLNULL: case LogicalTypeId::UNKNOWN: bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); break; default: throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); @@ -449,7 +448,7 @@ unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction & bound_function.arguments[2] = LogicalType::BIGINT; } - return make_uniq(bound_function.return_type, begin_is_empty, end_is_empty); + return make_uniq(bound_function.GetReturnType(), begin_is_empty, end_is_empty); } } // namespace @@ -457,8 +456,8 @@ ScalarFunctionSet ListSliceFun::GetFunctions() { // the arguments and return types are actually set in the binder function ScalarFunction fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, ArraySliceBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetFallible(); ScalarFunctionSet set; set.AddFunction(fun); fun.arguments.push_back(LogicalType::BIGINT); diff --git a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp index 97b3d625f..23cbf8660 100644 --- 
a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp @@ -10,7 +10,6 @@ namespace duckdb { namespace { void ListFlattenFunction(DataChunk &args, ExpressionState &, Vector &result) { - const auto flat_list_data = FlatVector::GetData(result); auto &flat_list_mask = FlatVector::Validity(result); diff --git a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp index 8be7134ab..eb06d4728 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp @@ -32,7 +32,7 @@ unique_ptr ListAggregatesInitLocalState(ExpressionState &sta unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); return make_uniq(LogicalType::SQLNULL); } @@ -93,10 +93,10 @@ struct StateVector { ~StateVector() { // NOLINT // destroy objects within the aggregate states auto &aggr = aggr_expr->Cast(); - if (aggr.function.destructor) { + if (aggr.function.HasStateDestructorCallback()) { ArenaAllocator allocator(Allocator::DefaultAllocator()); AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.destructor(state_vector, aggr_input_data, count); + aggr.function.GetStateDestructorCallback()(state_vector, aggr_input_data, count); } } @@ -187,7 +187,6 @@ struct UniqueFunctor { auto result_data = FlatVector::GetData(result); for (idx_t i = 0; i < count; i++) { - auto state = states[sdata.sel->get_index(i)]; if (!state->hist) { @@ -223,7 +222,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res allocator.Reset(); AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - D_ASSERT(aggr.function.update); + D_ASSERT(aggr.function.HasStateUpdateCallback()); auto lists_size = ListVector::GetListSize(lists); auto &child_vector = ListVector::GetEntry(lists); @@ -237,7 +236,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res auto list_entries = UnifiedVectorFormat::GetData(lists_data); // state_buffer holds the state for each list of this chunk - idx_t size = aggr.function.state_size(aggr.function); + idx_t size = aggr.function.GetStateSizeCallback()(aggr.function); auto state_buffer = make_unsafe_uniq_array_uninitialized(size * count); // state vector for initialize and finalize @@ -253,11 +252,10 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res idx_t states_idx = 0; for (idx_t i = 0; i < count; i++) { - // initialize the state for this list auto state_ptr = state_buffer.get() + size * i; states[i] = state_ptr; - aggr.function.initialize(aggr.function, states[i]); + aggr.function.GetStateInitCallback()(aggr.function, states[i]); auto lists_index = lists_data.sel->get_index(i); const auto &list_entry = list_entries[lists_index]; @@ -278,7 +276,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res if (states_idx == STANDARD_VECTOR_SIZE) { // update the aggregate state(s) Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); // reset values states_idx = 0; @@ -294,12 
+292,12 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res // update the remaining elements of the last list(s) if (states_idx != 0) { Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); + aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); } if (IS_AGGR) { // finalize all the aggregate states - aggr.function.finalize(state_vector.state_vector, aggr_input_data, result, count, 0); + aggr.function.GetStateFinalizeCallback()(state_vector.state_vector, aggr_input_data, result, count, 0); } else { // finalize manually to use the map @@ -390,7 +388,6 @@ template unique_ptr ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, AggregateFunction &aggr_function, vector> &arguments) { - // create the child expression and its type vector> children; auto expr = make_uniq(Value(list_child_type)); @@ -408,7 +405,7 @@ ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_functio bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); if (IS_AGGR) { - bound_function.return_type = bound_aggr_function->function.return_type; + bound_function.SetReturnType(bound_aggr_function->function.GetReturnType()); } // check if the aggregate function consumed all the extra input arguments if (bound_aggr_function->children.size() > 1) { @@ -417,13 +414,12 @@ ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_functio bound_aggr_function->ToString()); } - return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); + return make_uniq(bound_function.GetReturnType(), std::move(bound_aggr_function)); } template unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { @@ -459,7 +455,7 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti if (is_parameter) { bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; + bound_function.SetReturnType(LogicalType::SQLNULL); return nullptr; } @@ -481,7 +477,7 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti // found a matching function, bind it as an aggregate auto best_function = func.functions.GetFunctionByOffset(best_function_idx.GetIndex()); if (IS_AGGR) { - bound_function.errors = best_function.errors; + bound_function.SetErrorMode(best_function.GetErrorMode()); return ListAggregatesBindFunction(context, bound_function, child_type, best_function, arguments); } @@ -493,7 +489,6 @@ unique_ptr ListAggregatesBind(ClientContext &context, ScalarFuncti unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the name of the aggregate function D_ASSERT(bound_function.arguments.size() >= 2); D_ASSERT(arguments.size() >= 2); @@ -507,11 +502,11 @@ ScalarFunction ListAggregateFun::GetFunction() { auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, ListAggregateFunction, ListAggregateBind, nullptr, nullptr, ListAggregatesInitLocalState); - BaseScalarFunction::SetReturnsError(result); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + 
result.SetFallible(); + result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); result.varargs = LogicalType::ANY; - result.serialize = ListAggregatesBindData::SerializeFunction; - result.deserialize = ListAggregatesBindData::DeserializeFunction; + result.SetSerializeCallback(ListAggregatesBindData::SerializeFunction); + result.SetDeserializeCallback(ListAggregatesBindData::DeserializeFunction); return result; } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp index 5c3513b2a..ad0c488a0 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp @@ -88,7 +88,7 @@ ScalarFunctionSet ListDistanceFun::GetFunctions() { AddListFoldFunction(set, type); } for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } @@ -115,7 +115,7 @@ ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { AddListFoldFunction(set, type); } for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return set; } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp index 4224fad24..017c611a2 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp @@ -7,7 +7,6 @@ namespace duckdb { static unique_ptr ListFilterBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -25,7 +24,7 @@ static unique_ptr ListFilterBind(ClientContext &context, ScalarFun arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); auto has_index = bound_lambda_expr.parameter_count == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } @@ -39,10 +38,10 @@ ScalarFunction ListFilterFun::GetFunction() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), LambdaFunctions::ListFilterFunction, ListFilterBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListFilterBindLambda; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetSerializeCallback(ListLambdaBindData::Serialize); + fun.SetDeserializeCallback(ListLambdaBindData::Deserialize); + fun.SetBindLambdaCallback(ListFilterBindLambda); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp index 51b4980cd..ff2fd5354 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp @@ -7,7 +7,6 @@ namespace duckdb { static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto &l_vec = args.data[0]; auto &r_vec = args.data[1]; @@ -63,7 +62,6 @@ static void 
ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul // Use the smaller list to build the set if (r_list.length < l_list.length) { - build_list = r_list; probe_list = l_list; @@ -96,7 +94,6 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul } static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { - const auto &func_expr = state.expr.Cast(); const auto swap = func_expr.function.name == "<@"; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp index 08f64b54e..f3c23e98f 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp @@ -175,7 +175,6 @@ bool ExecuteReduce(const idx_t loops, ReduceExecuteInfo &execute_info, LambdaFun unique_ptr ListReduceBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2 || arguments.size() == 3); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -223,8 +222,8 @@ unique_ptr ListReduceBind(ClientContext &context, ScalarFunction & if (!cast_lambda_expr) { throw BinderException("Could not cast lambda expression to list child type"); } - bound_function.return_type = cast_lambda_expr->return_type; - return make_uniq(bound_function.return_type, std::move(cast_lambda_expr), has_index, + bound_function.SetReturnType(cast_lambda_expr->return_type); + return make_uniq(bound_function.GetReturnType(), std::move(cast_lambda_expr), has_index, has_initial); } @@ -311,10 +310,10 @@ ScalarFunctionSet ListReduceFun::GetFunctions() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::ANY, LambdaFunctions::ListReduceFunction, ListReduceBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListReduceBindLambda; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetSerializeCallback(ListLambdaBindData::Serialize); + fun.SetDeserializeCallback(ListLambdaBindData::Deserialize); + fun.SetBindLambdaCallback(ListReduceBindLambda); ScalarFunctionSet set; set.AddFunction(fun); diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp index 1263500c9..61fab0938 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp @@ -37,7 +37,6 @@ ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_ ClientContext &context_p) : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), is_grade_up(is_grade_up_p), context(context_p) { - // get the vector types types.emplace_back(LogicalType::USMALLINT); types.emplace_back(child_type); @@ -71,7 +70,6 @@ static void SinkDataChunk(const Sort &sort, ExecutionContext &context, OperatorS Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, Vector &payload_vector, bool &data_to_sort, Vector &lists_indices) { - // slice the child vector Vector slice(*child_vector, sel, offset_lists_indices); @@ -256,22 +254,22 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, 
Vector &re static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments, OrderType &order, OrderByNullType &null_order) { - LogicalType child_type; if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - child_type = bound_function.return_type; - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); + bound_function.SetReturnType(LogicalType::SQLNULL); + child_type = bound_function.GetReturnType(); + return make_uniq(order, null_order, false, bound_function.GetReturnType(), child_type, + context); } arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); child_type = ListType::GetChildType(arguments[0]->return_type); bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); + return make_uniq(order, null_order, false, bound_function.GetReturnType(), child_type, context); } template @@ -286,7 +284,6 @@ static T GetOrder(ClientContext &context, Expression &expr) { static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - D_ASSERT(!arguments.empty() && arguments.size() <= 3); auto order = OrderType::ORDER_DEFAULT; auto null_order = OrderByNullType::ORDER_DEFAULT; @@ -306,9 +303,9 @@ static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFu arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = LogicalType::LIST(LogicalTypeId::BIGINT); + bound_function.SetReturnType(LogicalType::LIST(LogicalTypeId::BIGINT)); auto child_type = ListType::GetChildType(arguments[0]->return_type); - return make_uniq(order, null_order, true, bound_function.return_type, child_type, context); + return make_uniq(order, null_order, true, bound_function.GetReturnType(), child_type, context); } static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, diff --git a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp index 97e8be006..be4f319a9 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp @@ -7,7 +7,6 @@ namespace duckdb { static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - // the list column and the bound lambda expression D_ASSERT(arguments.size() == 2); if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { @@ -17,7 +16,7 @@ static unique_ptr ListTransformBind(ClientContext &context, Scalar arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); auto &bound_lambda_expr = arguments[1]->Cast(); - bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); + bound_function.SetReturnType(LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type)); auto has_index = bound_lambda_expr.parameter_count == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } @@ -31,10 +30,10 @@ ScalarFunction 
ListTransformFun::GetFunction() { ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), LambdaFunctions::ListTransformFunction, ListTransformBind, nullptr, nullptr); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListTransformBindLambda; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetSerializeCallback(ListLambdaBindData::Serialize); + fun.SetDeserializeCallback(ListLambdaBindData::Deserialize); + fun.SetBindLambdaCallback(ListTransformBindLambda); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp index cec76fe89..fd556b7f0 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp @@ -291,8 +291,8 @@ unique_ptr UnpivotBind(ClientContext &context, ScalarFunction &bou // this is more for completeness reasons bound_function.varargs = child_type; - bound_function.return_type = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::LIST(child_type)); + return make_uniq(bound_function.GetReturnType()); } unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -303,13 +303,13 @@ unique_ptr ListValueStats(ClientContext &context, FunctionStatis for (idx_t i = 0; i < child_stats.size(); i++) { list_child_stats.Merge(child_stats[i]); } + list_stats.SetHasNoNullFast(); return list_stats.ToUnique(); } } // namespace ScalarFunctionSet ListValueFun::GetFunctions() { - ScalarFunctionSet set("list_value"); // Overload for 0 arguments, which returns an empty list. 
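The hunks above and below all perform the same mechanical migration: direct member assignment on ScalarFunction (null_handling, stability, serialize/deserialize, BaseScalarFunction::SetReturnsError) is replaced by the setter API used in the added lines. The following is a minimal illustrative sketch of that pattern only, not code from this diff; it assumes the DuckDB scalar function headers, and MyFunction / MyBind are hypothetical placeholders.

#include "duckdb/function/scalar_function.hpp"

namespace duckdb {

// Hypothetical callbacks, declared only so the sketch is self-contained.
static void MyFunction(DataChunk &args, ExpressionState &state, Vector &result);
static unique_ptr<FunctionData> MyBind(ClientContext &context, ScalarFunction &bound_function,
                                       vector<unique_ptr<Expression>> &arguments);

static ScalarFunction GetMySketchFunction() {
	ScalarFunction fun({LogicalType::ANY}, LogicalType::VARCHAR, MyFunction, MyBind);
	// previously: fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
	// previously: fun.stability = FunctionStability::VOLATILE;
	fun.SetStability(FunctionStability::VOLATILE);
	// previously: BaseScalarFunction::SetReturnsError(fun);
	fun.SetFallible();
	return fun;
}

} // namespace duckdb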
@@ -322,7 +322,7 @@ ScalarFunctionSet ListValueFun::GetFunctions() { ScalarFunction value_fun({element_type}, LogicalType::LIST(element_type), ListValueFunction, nullptr, nullptr, ListValueStats); value_fun.varargs = element_type; - value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + value_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); set.AddFunction(value_fun); return set; @@ -332,7 +332,7 @@ ScalarFunction UnpivotListFun::GetFunction() { ScalarFunction fun("unpivot_list", {}, LogicalTypeId::LIST, ListValueFunction, UnpivotBind, nullptr, ListValueStats); fun.varargs = LogicalTypeId::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/list/range.cpp b/src/duckdb/extension/core_functions/scalar/list/range.cpp index 494039d41..13281c09a 100644 --- a/src/duckdb/extension/core_functions/scalar/list/range.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/range.cpp @@ -258,7 +258,7 @@ ScalarFunctionSet ListRangeFun::GetFunctions() { LogicalType::LIST(LogicalType::TIMESTAMP), ListRangeFunction)); for (auto &func : range_set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return range_set; } @@ -277,7 +277,7 @@ ScalarFunctionSet GenerateSeriesFun::GetFunctions() { LogicalType::LIST(LogicalType::TIMESTAMP), ListRangeFunction)); for (auto &func : generate_series.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return generate_series; } diff --git a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp index 9c81223e7..9806b5d76 100644 --- a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp @@ -36,14 +36,14 @@ static unique_ptr CardinalityBind(ClientContext &context, ScalarFu throw BinderException("Cardinality can only operate on MAPs"); } - bound_function.return_type = LogicalType::UBIGINT; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::UBIGINT); + return make_uniq(bound_function.GetReturnType()); } ScalarFunction CardinalityFun::GetFunction() { ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::DEFAULT_NULL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp index ab9bea1bb..8b1e86a13 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map.cpp @@ -38,7 +38,6 @@ static bool MapIsNull(DataChunk &chunk) { } static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { - // internal MAP representation // - LIST-vector that contains STRUCTs as child entries // - STRUCTs have exactly two fields, a key-field, and a value-field @@ -107,7 +106,6 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { idx_t offset = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto keys_idx = keys_data.sel->get_index(row_idx); auto values_idx = values_data.sel->get_index(row_idx); auto result_idx = result_data.sel->get_index(row_idx); @@ -128,7 +126,6 @@ static void 
MapFunction(DataChunk &args, ExpressionState &, Vector &result) { // set the selection vectors and perform a duplicate key check value_set_t unique_keys; for (idx_t child_idx = 0; child_idx < keys_entry.length; child_idx++) { - auto key_idx = keys_child_data.sel->get_index(keys_entry.offset + child_idx); auto value_idx = values_child_data.sel->get_index(values_entry.offset + child_idx); @@ -173,16 +170,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { } ScalarFunctionSet MapFun::GetFunctions() { - ScalarFunction empty_func({}, LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL), MapFunction); - BaseScalarFunction::SetReturnsError(empty_func); + empty_func.SetFallible(); auto key_type = LogicalType::TEMPLATE("K"); auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction value_func({LogicalType::LIST(key_type), LogicalType::LIST(val_type)}, LogicalType::MAP(key_type, val_type), MapFunction); - BaseScalarFunction::SetReturnsError(value_func); - value_func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + value_func.SetFallible(); + value_func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); ScalarFunctionSet set; diff --git a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp index 4c733d56f..33fac37ec 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp @@ -132,7 +132,6 @@ bool IsEmptyMap(const LogicalType &map) { unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto arg_count = arguments.size(); if (arg_count < 2) { throw InvalidInputException("The provided amount of arguments is incorrect, please provide 2 or more maps"); @@ -141,7 +140,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { // Prepared statement bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + bound_function.SetReturnType(LogicalTypeId::SQLNULL); return nullptr; } @@ -155,7 +154,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (map.id() == LogicalTypeId::UNKNOWN) { // Prepared statement bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); + bound_function.SetReturnType(LogicalTypeId::SQLNULL); return nullptr; } if (map.id() == LogicalTypeId::SQLNULL) { @@ -183,8 +182,8 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); } - bound_function.return_type = expected; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(expected); + return make_uniq(bound_function.GetReturnType()); } } // namespace @@ -192,7 +191,7 @@ unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &b ScalarFunction MapConcatFun::GetFunction() { //! 
the arguments and return types are actually set in the binder function ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp index 06af34e66..0d9372903 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp @@ -29,14 +29,13 @@ static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector & } ScalarFunction MapEntriesFun::GetFunction() { - auto key_type = LogicalType::TEMPLATE("K"); auto val_type = LogicalType::TEMPLATE("V"); auto map_type = LogicalType::MAP(key_type, val_type); auto row_type = LogicalType::STRUCT({{"key", key_type}, {"value", val_type}}); ScalarFunction fun({map_type}, LogicalType::LIST(row_type), MapEntriesFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp index fcea0b133..b7b8a3091 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp @@ -118,7 +118,7 @@ ScalarFunction MapExtractValueFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, val_type, MapExtractValueFunc); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } @@ -128,7 +128,7 @@ ScalarFunction MapExtractFun::GetFunction() { ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, LogicalType::LIST(val_type), MapExtractListFunc); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp index 2344b9a6e..169b9177c 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp @@ -26,9 +26,9 @@ ScalarFunction MapFromEntriesFun::GetFunction() { auto row_type = LogicalType::STRUCT({{"", key_type}, {"", val_type}}); ScalarFunction fun({LogicalType::LIST(row_type)}, map_type, MapFromEntriesFunction); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::DEFAULT_NULL_HANDLING); - BaseScalarFunction::SetReturnsError(fun); + fun.SetFallible(); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp index eec32a0a6..2ee626ba1 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp @@ -57,9 +57,9 @@ ScalarFunction MapKeysFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(key_type), MapKeysFunction); - 
function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -68,9 +68,9 @@ ScalarFunction MapValuesFun::GetFunction() { auto val_type = LogicalType::TEMPLATE("V"); ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(val_type), MapValuesFunction); - function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp index d6ae71bc0..799d4cb35 100644 --- a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp +++ b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp @@ -131,7 +131,7 @@ static unique_ptr PropagateAbsStats(ClientContext &context, Func } new_min = Value::Numeric(expr.return_type, min_val); new_max = Value::Numeric(expr.return_type, max_val); - expr.function.function = ScalarFunction::GetScalarUnaryFunction(expr.return_type); + expr.function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(expr.return_type)); } auto stats = NumericStats::CreateEmpty(expr.return_type); NumericStats::SetMin(stats, new_min); @@ -141,25 +141,25 @@ static unique_ptr PropagateAbsStats(ClientContext &context, Func } template -unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +static unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { auto decimal_type = arguments[0]->return_type; switch (decimal_type.InternalType()) { case PhysicalType::INT16: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); + bound_function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT)); break; case PhysicalType::INT32: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); + bound_function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER)); break; case PhysicalType::INT64: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); + bound_function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT)); break; default: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); + bound_function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT)); break; } bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; + bound_function.SetReturnType(decimal_type); return nullptr; } @@ -176,7 +176,7 @@ ScalarFunctionSet AbsOperatorFun::GetFunctions() { case LogicalTypeId::BIGINT: case LogicalTypeId::HUGEINT: { ScalarFunction function({type}, type, ScalarFunction::GetScalarUnaryFunction(type)); - function.statistics = PropagateAbsStats; + function.SetStatisticsCallback(PropagateAbsStats); abs.AddFunction(function); break; } @@ -192,7 +192,7 @@ ScalarFunctionSet AbsOperatorFun::GetFunctions() { } } for (auto &func : abs.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return abs; } @@ -338,25 +338,25 @@ static unique_ptr BindGenericRoundFunctionDecimal(ClientContext &c auto scale = 
DecimalType::GetScale(decimal_type); auto width = DecimalType::GetWidth(decimal_type); if (scale == 0) { - bound_function.function = ScalarFunction::NopFunction; + bound_function.SetFunctionCallback(ScalarFunction::NopFunction); } else { switch (decimal_type.InternalType()) { case PhysicalType::INT16: - bound_function.function = GenericRoundFunctionDecimal; + bound_function.SetFunctionCallback(GenericRoundFunctionDecimal); break; case PhysicalType::INT32: - bound_function.function = GenericRoundFunctionDecimal; + bound_function.SetFunctionCallback(GenericRoundFunctionDecimal); break; case PhysicalType::INT64: - bound_function.function = GenericRoundFunctionDecimal; + bound_function.SetFunctionCallback(GenericRoundFunctionDecimal); break; default: - bound_function.function = GenericRoundFunctionDecimal; + bound_function.SetFunctionCallback(GenericRoundFunctionDecimal); break; } } bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, 0); + bound_function.SetReturnType(LogicalType::DECIMAL(width, 0)); return nullptr; } @@ -482,13 +482,13 @@ struct RoundPrecisionFunctionData : public FunctionData { }; template -static void GenericRoundPrecisionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { +void GenericRoundPrecisionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { OP::template Operation(input, state, result); } template -static unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { auto &decimal_type = arguments[0]->return_type; if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); @@ -514,43 +514,43 @@ static unique_ptr BindDecimalRoundPrecision(ClientContext &context target_scale = 0; switch (decimal_type.InternalType()) { case PhysicalType::INT16: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; case PhysicalType::INT32: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; case PhysicalType::INT64: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; default: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; } } else { if (round_value >= (int32_t)scale) { // if round_value is bigger than or equal to scale we do nothing - bound_function.function = ScalarFunction::NopFunction; + bound_function.SetFunctionCallback(ScalarFunction::NopFunction); target_scale = scale; } else { target_scale = NumericCast(round_value); switch (decimal_type.InternalType()) { case PhysicalType::INT16: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; case PhysicalType::INT32: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; case PhysicalType::INT64: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; default: - bound_function.function = GenericRoundPrecisionDecimal; + bound_function.SetFunctionCallback(GenericRoundPrecisionDecimal); break; } } } 
bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, target_scale); + bound_function.SetReturnType(LogicalType::DECIMAL(width, target_scale)); return make_uniq(round_value); } @@ -972,7 +972,7 @@ struct SqrtOperator { ScalarFunction SqrtFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1017,7 +1017,7 @@ struct LnOperator { ScalarFunction LnFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1044,7 +1044,7 @@ struct Log10Operator { ScalarFunction Log10Fun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1073,7 +1073,7 @@ ScalarFunctionSet LogFun::GetFunctions() { funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::BinaryFunction)); for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return funcs; } @@ -1099,7 +1099,7 @@ struct Log2Operator { ScalarFunction Log2Fun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1289,7 +1289,7 @@ struct SinOperator { ScalarFunction SinFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1308,7 +1308,7 @@ struct CosOperator { ScalarFunction CosFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1327,7 +1327,7 @@ struct TanOperator { ScalarFunction TanFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1349,7 +1349,7 @@ struct ASinOperator { ScalarFunction AsinFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1405,7 +1405,7 @@ struct ACos { ScalarFunction AcosFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1515,7 +1515,7 @@ struct AtanhOperator { ScalarFunction AtanhFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1550,7 +1550,7 @@ struct CotOperator { ScalarFunction CotFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction>); - 
BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1572,7 +1572,7 @@ struct GammaOperator { ScalarFunction GammaFun::GetFunction() { auto func = ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } @@ -1594,7 +1594,7 @@ struct LogGammaOperator { ScalarFunction LogGammaFun::GetFunction() { ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1619,7 +1619,7 @@ struct FactorialOperator { ScalarFunction FactorialOperatorFun::GetFunction() { ScalarFunction function({LogicalType::INTEGER}, LogicalType::HUGEINT, ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -1735,7 +1735,7 @@ ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, ScalarFunction::BinaryFunction)); for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return funcs; } diff --git a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp index 56844b0f1..9e65138c2 100644 --- a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp +++ b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp @@ -116,7 +116,7 @@ ScalarFunctionSet BitwiseAndFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -153,7 +153,7 @@ ScalarFunctionSet BitwiseOrFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseOROperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -190,7 +190,7 @@ ScalarFunctionSet BitwiseXorFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -225,7 +225,7 @@ ScalarFunctionSet BitwiseNotFun::GetFunctions() { } functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -294,7 +294,7 @@ ScalarFunctionSet LeftShiftFun::GetFunctions() { functions.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); } return functions; } @@ -344,7 +344,7 @@ ScalarFunctionSet RightShiftFun::GetFunctions() { functions.AddFunction( ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); + 
function.SetFallible(); } return functions; } diff --git a/src/duckdb/extension/core_functions/scalar/random/random.cpp b/src/duckdb/extension/core_functions/scalar/random/random.cpp index 589e264b4..738556161 100644 --- a/src/duckdb/extension/core_functions/scalar/random/random.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/random.cpp @@ -114,7 +114,7 @@ void GenerateUUIDv7Function(DataChunk &args, ExpressionState &state, Vector &res ScalarFunction RandomFun::GetFunction() { ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, RandomInitLocalState); - random.stability = FunctionStability::VOLATILE; + random.SetStability(FunctionStability::VOLATILE); return random; } @@ -126,7 +126,7 @@ ScalarFunction UUIDv4Fun::GetFunction() { ScalarFunction uuid_v4_function({}, LogicalType::UUID, GenerateUUIDv4Function, nullptr, nullptr, nullptr, RandomInitLocalState); // generate a random uuid v4 - uuid_v4_function.stability = FunctionStability::VOLATILE; + uuid_v4_function.SetStability(FunctionStability::VOLATILE); return uuid_v4_function; } @@ -134,7 +134,7 @@ ScalarFunction UUIDv7Fun::GetFunction() { ScalarFunction uuid_v7_function({}, LogicalType::UUID, GenerateUUIDv7Function, nullptr, nullptr, nullptr, RandomInitLocalState); // generate a random uuid v7 - uuid_v7_function.stability = FunctionStability::VOLATILE; + uuid_v7_function.SetStability(FunctionStability::VOLATILE); return uuid_v7_function; } diff --git a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp index 29072de56..1364b7ddf 100644 --- a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp @@ -58,8 +58,8 @@ unique_ptr SetSeedBind(ClientContext &context, ScalarFunction &bou ScalarFunction SetseedFun::GetFunction() { ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); - setseed.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(setseed); + setseed.SetVolatile(); + setseed.SetFallible(); return setseed; } diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp index d3d6eee7b..6ce5db5ce 100644 --- a/src/duckdb/extension/core_functions/scalar/string/hex.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/hex.cpp @@ -89,7 +89,6 @@ struct HexStrOperator { struct HexIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -119,7 +118,6 @@ struct HexIntegralOperator { struct HexHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); @@ -146,7 +144,6 @@ struct HexHugeIntOperator { struct HexUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); @@ -204,7 +201,6 @@ struct BinaryStrOperator { struct BinaryIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto 
num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -409,7 +405,7 @@ ScalarFunctionSet HexFun::GetFunctions() { ScalarFunction UnhexFun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -433,7 +429,7 @@ ScalarFunctionSet BinFun::GetFunctions() { ScalarFunction UnbinFun::GetFunction() { ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/string/instr.cpp b/src/duckdb/extension/core_functions/scalar/string/instr.cpp index cc0fde9f1..47797d914 100644 --- a/src/duckdb/extension/core_functions/scalar/string/instr.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/instr.cpp @@ -44,7 +44,8 @@ static unique_ptr InStrPropagateStats(ClientContext &context, Fu // can only propagate stats if the children have stats // for strpos, we only care if the FIRST string has unicode or not if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::BinaryFunction; + expr.function.SetFunctionCallback( + ScalarFunction::BinaryFunction); } return nullptr; } @@ -53,7 +54,7 @@ ScalarFunction InstrFun::GetFunction() { auto function = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, ScalarFunction::BinaryFunction, nullptr, nullptr, InStrPropagateStats); - function.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + function.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return function; } diff --git a/src/duckdb/extension/core_functions/scalar/string/pad.cpp b/src/duckdb/extension/core_functions/scalar/string/pad.cpp index 586e1605a..44fb8a763 100644 --- a/src/duckdb/extension/core_functions/scalar/string/pad.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/pad.cpp @@ -133,14 +133,14 @@ static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) ScalarFunction LpadFun::GetFunction() { ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, PadFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } ScalarFunction RpadFun::GetFunction() { ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, PadFunction); - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp index 1ec8ae2cd..b98512d51 100644 --- a/src/duckdb/extension/core_functions/scalar/string/printf.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/printf.cpp @@ -189,7 +189,7 @@ ScalarFunction PrintfFun::GetFunction() { ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, PrintfFunction, BindPrintfFunction); printf_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(printf_fun); + printf_fun.SetFallible(); return printf_fun; } @@ -198,7 +198,7 @@ ScalarFunction FormatFun::GetFunction() { ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, PrintfFunction, BindPrintfFunction); 
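For reference, the scalar-function hunks in this patch all follow the same mechanical migration: direct member writes on BaseScalarFunction are replaced by the accessor methods introduced in this change. A minimal sketch of the resulting pattern, assuming the duckdb scalar_function headers; ExampleOperator and GetExampleFunction are placeholder names, and the setters shown (SetFallible, SetStability) are the ones used in the surrounding hunks, not a verbatim excerpt of the patch:

struct ExampleOperator {
	// Placeholder unary operator following the UnaryExecutor convention
	template <class TA, class TR>
	static TR Operation(TA input) {
		return input;
	}
};

ScalarFunction GetExampleFunction() {
	ScalarFunction fun({LogicalType::DOUBLE}, LogicalType::DOUBLE,
	                   ScalarFunction::UnaryFunction<double, double, ExampleOperator>);
	fun.SetFallible();                             // previously: BaseScalarFunction::SetReturnsError(fun);
	fun.SetStability(FunctionStability::VOLATILE); // previously: fun.stability = FunctionStability::VOLATILE;
	return fun;
}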
format_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(format_fun); + format_fun.SetFallible(); return format_fun; } diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp index 2bfceae03..c93bbfa5e 100644 --- a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp @@ -67,7 +67,7 @@ ScalarFunctionSet RepeatFun::GetFunctions() { repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, LogicalType::LIST(LogicalType::TEMPLATE("T")), RepeatListFunction)); for (auto &func : repeat.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return repeat; } diff --git a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp index 7ef277292..dd918c8a3 100644 --- a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp @@ -17,7 +17,6 @@ static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const } static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) { - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); auto haystack_size = haystack_s.GetSize(); auto needle = const_uchar_ptr_cast(needle_s.GetData()); @@ -39,7 +38,7 @@ struct StartsWithOperator { ScalarFunction StartsWithOperatorFun::GetFunction() { ScalarFunction starts_with({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - starts_with.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + starts_with.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return starts_with; } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp index cc4fd6f01..7fade78ba 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp @@ -68,8 +68,8 @@ static unique_ptr StructInsertBind(ClientContext &context, ScalarF new_children.push_back(make_pair(child->GetAlias(), arguments[i]->return_type)); } - bound_function.return_type = LogicalType::STRUCT(new_children); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::STRUCT(new_children)); + return make_uniq(bound_function.GetReturnType()); } static unique_ptr StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -93,10 +93,10 @@ static unique_ptr StructInsertStats(ClientContext &context, Func ScalarFunction StructInsertFun::GetFunction() { ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; + fun.SetSerializeCallback(VariableReturnBindData::Serialize); + fun.SetDeserializeCallback(VariableReturnBindData::Deserialize); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp new file mode 
100644 index 000000000..f021408e8 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp @@ -0,0 +1,94 @@ +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "core_functions/scalar/struct_functions.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +struct StructKeysBindData : public FunctionData { + const LogicalType type; + Vector keys_vector; + + explicit StructKeysBindData(const LogicalType &type_p) + : type(type_p), keys_vector(LogicalType::LIST(LogicalType::VARCHAR), 2) { + const auto &child_types = StructType::GetChildTypes(type); + const auto count = child_types.size(); + + ListVector::Reserve(keys_vector, count); + auto &list_child = ListVector::GetEntry(keys_vector); + auto child_data = FlatVector::GetData(list_child); + for (idx_t i = 0; i < count; i++) { + child_data[i] = StringVector::AddString(list_child, child_types[i].first); + } + ListVector::SetListSize(keys_vector, count); + + auto list_entries = FlatVector::GetData(keys_vector); + list_entries[0] = {0, count}; + + auto &validity = FlatVector::Validity(keys_vector); + validity.EnsureWritable(); + validity.SetInvalid(1); + } + + bool Equals(const FunctionData &other) const override { + auto &o = other.Cast(); + // Compare type and flag (content is derived from them) + return type == o.type; + } + + unique_ptr Copy() const override { + return make_uniq(type); + } +}; + +static void StructKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &input = args.data[0]; + const idx_t count = args.size(); + + auto &data = state.expr.Cast().bind_info->Cast(); + auto &keys_vector = data.keys_vector; + + // If the input is a constant, we must return a CONSTANT_VECTOR + if (args.AllConstant()) { + if (ConstantVector::IsNull(input)) { + ConstantVector::SetNull(result, true); + return; + } + ConstantVector::Reference(result, keys_vector, 0, count); + return; + } + + // Non-constant input: return a DICTIONARY_VECTOR over two entries (keys list and NULL) to preserve per-row NULLs + // Build the dictionary selection: 0 for non-null input, 1 for null input + SelectionVector sel(count); + UnifiedVectorFormat input_data; + input.ToUnifiedFormat(count, input_data); + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + const bool is_valid = input_data.validity.RowIsValid(idx); + sel.set_index(i, !is_valid); + } + + result.Slice(keys_vector, sel, count); +} + +static unique_ptr StructKeysBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments[0]->return_type.id() != LogicalTypeId::STRUCT) { + throw InvalidInputException("struct_keys() expects a STRUCT argument"); + } + + const bool is_unnamed = StructType::IsUnnamed(arguments[0]->return_type); + if (is_unnamed) { + throw InvalidInputException("struct_keys() cannot be applied to an unnamed STRUCT"); + } + return make_uniq(arguments[0]->return_type); +} + +ScalarFunction StructKeysFun::GetFunction() { + ScalarFunction func({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), StructKeysFunction, + StructKeysBind); + return func; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp index e83c9b884..0c099c8b9 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp @@ -108,8 +108,8 @@ static unique_ptr StructUpdateBind(ClientContext &context, ScalarF } } - bound_function.return_type = LogicalType::STRUCT(new_children); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::STRUCT(new_children)); + return make_uniq(bound_function.GetReturnType()); } unique_ptr StructUpdateStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -151,10 +151,10 @@ unique_ptr StructUpdateStats(ClientContext &context, FunctionSta ScalarFunction StructUpdateFun::GetFunction() { ScalarFunction fun({}, LogicalTypeId::STRUCT, StructUpdateFunction, StructUpdateBind, nullptr, StructUpdateStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.varargs = LogicalType::ANY; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; + fun.SetSerializeCallback(VariableReturnBindData::Serialize); + fun.SetDeserializeCallback(VariableReturnBindData::Deserialize); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp new file mode 100644 index 000000000..3c247ef88 --- /dev/null +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp @@ -0,0 +1,81 @@ +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" // VariableReturnBindData +#include "core_functions/scalar/struct_functions.hpp" + +namespace duckdb { + +static void StructValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + auto &input = args.data[0]; + const idx_t count = args.size(); + + auto &input_children = StructVector::GetEntries(input); + auto &result_children = StructVector::GetEntries(result); + D_ASSERT(result_children.size() == input_children.size()); + + // UnnamedStruct vector and Struct vector are actually the same underneath, so we can just reference the children + if (StructType::IsUnnamed(input.GetType())) { + result.Reference(input); + return; + } + + // We would use result.Reference(input) also for this case, + // but that function asserts that the logical types are the same + for (idx_t i = 0; i < input_children.size(); i++) { + result_children[i]->Reference(*input_children[i]); + } + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + const bool is_null = ConstantVector::IsNull(input); + ConstantVector::SetNull(result, is_null); + } else { + result.SetVectorType(VectorType::FLAT_VECTOR); + + // Make result validity to mirror input's nulls + UnifiedVectorFormat input_data; + input.ToUnifiedFormat(count, input_data); + + if (!input_data.validity.AllValid()) { + auto &validity = FlatVector::Validity(result); + + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + validity.SetInvalid(i); + } + } + } + } +} + +// Ensure input is a STRUCT, set return type to an unnamed STRUCT with same child types +static unique_ptr StructValuesBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + const auto arg_type = arguments[0]->return_type; + if (arg_type == LogicalTypeId::UNKNOWN) { + throw ParameterNotResolvedException(); + } 
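+	// (UNKNOWN above means the argument is an unresolved prepared-statement parameter; the bind is retried once its type is known.)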
+ + // Since the type of the argument we declared of in `GetFunction` doesn't contain the inner STRUCT type, + // we should take it from the arguments + bound_function.arguments[0] = arg_type; + + // Build unnamed children list using only types, with empty names + child_list_t unnamed_children; + auto &children = StructType::GetChildTypes(arguments[0]->return_type); + unnamed_children.reserve(children.size()); + for (auto &child : children) { + unnamed_children.emplace_back("", child.second); + } + bound_function.SetReturnType(LogicalType::STRUCT(unnamed_children)); + return nullptr; +} + +ScalarFunction StructValuesFun::GetFunction() { + ScalarFunction func({LogicalTypeId::STRUCT}, LogicalTypeId::STRUCT, StructValuesFunction, StructValuesBind); + return func; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp index b322f18ea..3feed5a2a 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp @@ -97,7 +97,7 @@ unique_ptr UnionExtractBind(ClientContext &context, ScalarFunction throw BinderException("Could not find key \"%s\" in union\n%s", key, message); } - bound_function.return_type = return_type; + bound_function.SetReturnType(return_type); return make_uniq(key, key_index, return_type); } diff --git a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp index 95f63590a..98f210fa3 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp @@ -10,7 +10,6 @@ namespace { unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - if (arguments.empty()) { throw BinderException("Missing required arguments for union_tag function."); } @@ -42,7 +41,7 @@ unique_ptr UnionTagBind(ClientContext &context, ScalarFunction &bo str.IsInlined() ? 
str : StringVector::AddString(varchar_vector, str); } auto enum_type = LogicalType::ENUM(varchar_vector, member_count); - bound_function.return_type = enum_type; + bound_function.SetReturnType(enum_type); return nullptr; } diff --git a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp index 44274b3fd..b177a9b4e 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp @@ -40,7 +40,6 @@ void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) unique_ptr UnionValueBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - if (arguments.size() != 1) { throw BinderException("union_value takes exactly one argument"); } @@ -54,8 +53,8 @@ unique_ptr UnionValueBind(ClientContext &context, ScalarFunction & union_members.push_back(make_pair(child->GetAlias(), child->return_type)); - bound_function.return_type = LogicalType::UNION(std::move(union_members)); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::UNION(std::move(union_members))); + return make_uniq(bound_function.GetReturnType()); } } // namespace @@ -63,9 +62,9 @@ unique_ptr UnionValueBind(ClientContext &context, ScalarFunction & ScalarFunction UnionValueFun::GetFunction() { ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetSerializeCallback(VariableReturnBindData::Serialize); + fun.SetDeserializeCallback(VariableReturnBindData::Deserialize); return fun; } diff --git a/src/duckdb/extension/icu/icu-current.cpp b/src/duckdb/extension/icu/icu-current.cpp index 65bf29c54..76a7ae0f3 100644 --- a/src/duckdb/extension/icu/icu-current.cpp +++ b/src/duckdb/extension/icu/icu-current.cpp @@ -36,13 +36,13 @@ static void CurrentDateFunction(DataChunk &input, ExpressionState &state, Vector ScalarFunction GetCurrentTimeFun() { ScalarFunction current_time({}, LogicalType::TIME_TZ, CurrentTimeFunction); - current_time.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_time.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_time; } ScalarFunction GetCurrentDateFun() { ScalarFunction current_date({}, LogicalType::DATE, CurrentDateFunction); - current_date.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; + current_date.SetStability(FunctionStability::CONSISTENT_WITHIN_QUERY); return current_date; } diff --git a/src/duckdb/extension/icu/icu-dateadd.cpp b/src/duckdb/extension/icu/icu-dateadd.cpp index 7f979f8a5..56a025861 100644 --- a/src/duckdb/extension/icu/icu-dateadd.cpp +++ b/src/duckdb/extension/icu/icu-dateadd.cpp @@ -219,7 +219,6 @@ interval_t ICUCalendarAge::Operation(timestamp_t end_date, timestamp_t start_dat } struct ICUDateAdd : public ICUDateFunc { - template static void ExecuteUnary(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); diff --git a/src/duckdb/extension/icu/icu-datefunc.cpp b/src/duckdb/extension/icu/icu-datefunc.cpp index 2d5fdce78..b0924b83a 100644 --- a/src/duckdb/extension/icu/icu-datefunc.cpp +++ b/src/duckdb/extension/icu/icu-datefunc.cpp @@ -16,7 
+16,6 @@ ICUDateFunc::BindData::BindData(const BindData &other) ICUDateFunc::BindData::BindData(const string &tz_setting_p, const string &cal_setting_p) : tz_setting(tz_setting_p), cal_setting(cal_setting_p) { - InitCalendar(); } diff --git a/src/duckdb/extension/icu/icu-datepart.cpp b/src/duckdb/extension/icu/icu-datepart.cpp index 570430283..445fe35c3 100644 --- a/src/duckdb/extension/icu/icu-datepart.cpp +++ b/src/duckdb/extension/icu/icu-datepart.cpp @@ -500,8 +500,8 @@ struct ICUDatePart : public ICUDateFunc { arguments.erase(arguments.begin()); bound_function.arguments.erase(bound_function.arguments.begin()); bound_function.name = part_name; - bound_function.return_type = LogicalType::DOUBLE; - bound_function.function = UnaryTimestampFunction; + bound_function.SetReturnType(LogicalType::DOUBLE); + bound_function.SetFunctionCallback(UnaryTimestampFunction); return BindUnaryDatePart(context, bound_function, arguments); } while (false); @@ -554,7 +554,7 @@ struct ICUDatePart : public ICUDateFunc { } Function::EraseArgument(bound_function, arguments, 0); - bound_function.return_type = LogicalType::STRUCT(std::move(struct_children)); + bound_function.SetReturnType(LogicalType::STRUCT(std::move(struct_children))); return make_uniq(context, std::move(part_codes)); } @@ -601,8 +601,8 @@ struct ICUDatePart : public ICUDateFunc { auto part_type = LogicalType::LIST(LogicalType::VARCHAR); auto result_type = LogicalType::STRUCT({}); ScalarFunction result({part_type, temporal_type}, result_type, StructFunction, BindStruct); - result.serialize = SerializeStructFunction; - result.deserialize = DeserializeStructFunction; + result.SetSerializeCallback(SerializeStructFunction); + result.SetDeserializeCallback(DeserializeStructFunction); return result; } @@ -611,7 +611,7 @@ struct ICUDatePart : public ICUDateFunc { set.AddFunction(GetBinaryPartCodeFunction(LogicalType::TIMESTAMP_TZ)); set.AddFunction(GetStructFunction(LogicalType::TIMESTAMP_TZ)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-datesub.cpp b/src/duckdb/extension/icu/icu-datesub.cpp index 00e14f9e1..9c5edbcc1 100644 --- a/src/duckdb/extension/icu/icu-datesub.cpp +++ b/src/duckdb/extension/icu/icu-datesub.cpp @@ -9,7 +9,6 @@ namespace duckdb { struct ICUCalendarSub : public ICUDateFunc { - // ICU only has 32 bit precision for date parts, so it can overflow a high resolution. // Since there is no difference between ICU and the obvious calculations, // we make these using the DuckDB internal type. 
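As an aside on the 32-bit limitation mentioned in the comment above: a signed 32-bit count of microseconds overflows after roughly 35 minutes, so any high-resolution difference has to be computed with DuckDB's own 64-bit arithmetic rather than ICU's date-part fields. A tiny illustrative check (not part of the patch):

#include <cstdint>
static_assert(INT32_MAX / 1000000 / 60 == 35, "int32 microseconds cover less than 36 minutes");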
@@ -192,7 +191,6 @@ ICUDateFunc::part_sub_t ICUDateFunc::SubtractFactory(DatePartSpecifier type) { // MS-SQL differences can be computed using ICU by truncating both arguments // to the desired part precision and then applying ICU subtraction/difference struct ICUCalendarDiff : public ICUDateFunc { - template static int64_t DifferenceFunc(icu::Calendar *calendar, timestamp_t start_date, timestamp_t end_date, part_trunc_t trunc_func, part_sub_t sub_func) { diff --git a/src/duckdb/extension/icu/icu-list-range.cpp b/src/duckdb/extension/icu/icu-list-range.cpp index 4ee9e0b46..a1ec558e2 100644 --- a/src/duckdb/extension/icu/icu-list-range.cpp +++ b/src/duckdb/extension/icu/icu-list-range.cpp @@ -181,7 +181,6 @@ struct ICUListRange : public ICUDateFunc { } static void AddICUListRangeFunction(ExtensionLoader &loader) { - ScalarFunctionSet range("range"); range.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL}, LogicalType::LIST(LogicalType::TIMESTAMP_TZ), ICUListRangeFunction, diff --git a/src/duckdb/extension/icu/icu-makedate.cpp b/src/duckdb/extension/icu/icu-makedate.cpp index 7c8efb2cb..128e80d93 100644 --- a/src/duckdb/extension/icu/icu-makedate.cpp +++ b/src/duckdb/extension/icu/icu-makedate.cpp @@ -145,7 +145,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { static ScalarFunction GetSenaryFunction(const LogicalTypeId &type) { ScalarFunction function({type, type, type, type, type, LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, Execute, Bind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -153,7 +153,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { static ScalarFunction GetSeptenaryFunction(const LogicalTypeId &type) { ScalarFunction function({type, type, type, type, type, LogicalType::DOUBLE, LogicalType::VARCHAR}, LogicalType::TIMESTAMP_TZ, Execute, Bind); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } @@ -162,7 +162,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { set.AddFunction(GetSenaryFunction(LogicalType::BIGINT)); set.AddFunction(GetSeptenaryFunction(LogicalType::BIGINT)); ScalarFunction function({LogicalType::BIGINT}, LogicalType::TIMESTAMP_TZ, FromMicros); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); set.AddFunction(function); loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-strptime.cpp b/src/duckdb/extension/icu/icu-strptime.cpp index 63e383dcd..f2ca158c2 100644 --- a/src/duckdb/extension/icu/icu-strptime.cpp +++ b/src/duckdb/extension/icu/icu-strptime.cpp @@ -203,8 +203,8 @@ struct ICUStrptime : public ICUDateFunc { // If we have a time zone, we should use ICU for parsing and return a TSTZ instead. 
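(Here "a time zone" means the format contains a time-zone-name specifier, e.g. strptime(col, '%Y-%m-%d %H:%M:%S %Z'); in that case the bind below switches the function and its return type to TIMESTAMP_TZ.)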
if (format.HasFormatSpecifier(StrTimeSpecifier::TZ_NAME)) { - bound_function.function = function; - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetFunctionCallback(function); + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); return make_uniq(context, format); } } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { @@ -227,14 +227,14 @@ struct ICUStrptime : public ICUDateFunc { formats.emplace_back(format); } if (has_tz) { - bound_function.function = function; - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetFunctionCallback(function); + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); return make_uniq(context, formats); } } // Fall back to faster, non-TZ parsing - bound_function.bind = bind_strptime; + bound_function.SetBindCallback(bind_strptime); return bind_strptime(context, bound_function, arguments); } @@ -254,8 +254,8 @@ struct ICUStrptime : public ICUDateFunc { throw InternalException("ICU - Function for TailPatch not found"); } auto &bound_function = functions[best_index.GetIndex()]; - bind_strptime = bound_function.bind; - bound_function.bind = StrpTimeBindFunction; + bind_strptime = bound_function.GetBindCallback(); + bound_function.SetBindCallback(StrpTimeBindFunction); } static void AddBinaryTimestampFunction(const string &name, ExtensionLoader &loader) { diff --git a/src/duckdb/extension/icu/icu-timebucket.cpp b/src/duckdb/extension/icu/icu-timebucket.cpp index 1336e0189..9a4035d18 100644 --- a/src/duckdb/extension/icu/icu-timebucket.cpp +++ b/src/duckdb/extension/icu/icu-timebucket.cpp @@ -16,7 +16,6 @@ namespace duckdb { struct ICUTimeBucket : public ICUDateFunc { - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... 
for TimescaleDB compatibility // There are 10959 days between 1970-01-01 and 2000-01-03 constexpr static const int64_t DEFAULT_ORIGIN_MICROS_1 = 10959 * Interval::MICROS_PER_DAY; @@ -630,7 +629,7 @@ struct ICUTimeBucket : public ICUDateFunc { set.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP_TZ, LogicalType::VARCHAR}, LogicalType::TIMESTAMP_TZ, ICUTimeBucketTimeZoneFunction, Bind)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-timezone.cpp b/src/duckdb/extension/icu/icu-timezone.cpp index 86b8b6033..65993beaf 100644 --- a/src/duckdb/extension/icu/icu-timezone.cpp +++ b/src/duckdb/extension/icu/icu-timezone.cpp @@ -267,7 +267,6 @@ struct ICUToNaiveTimestamp : public ICUDateFunc { }; struct ICULocalTimestampFunc : public ICUDateFunc { - struct BindDataNow : public BindData { explicit BindDataNow(ClientContext &context) : BindData(context) { now = MetaTransaction::Get(context).start_timestamp; @@ -452,7 +451,7 @@ struct ICUTimeZoneFunc : public ICUDateFunc { set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME_TZ}, LogicalType::TIME_TZ, Execute, Bind)); for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu_extension.cpp b/src/duckdb/extension/icu/icu_extension.cpp index 006283576..c4e02d0ac 100644 --- a/src/duckdb/extension/icu/icu_extension.cpp +++ b/src/duckdb/extension/icu/icu_extension.cpp @@ -5,11 +5,8 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/connection.hpp" -#include "duckdb/main/database.hpp" #include "duckdb/main/extension/extension_loader.hpp" #include "duckdb/parser/parsed_data/create_collation_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "include/icu-current.hpp" #include "include/icu-dateadd.hpp" @@ -25,8 +22,6 @@ #include "include/icu_extension.hpp" #include "unicode/calendar.h" #include "unicode/coll.h" -#include "unicode/errorcode.h" -#include "unicode/sortkey.h" #include "unicode/stringpiece.h" #include "unicode/timezone.h" #include "unicode/ucol.h" @@ -204,12 +199,12 @@ static ScalarFunction GetICUCollateFunction(const string &collation, const strin ScalarFunction result(fname, {LogicalType::VARCHAR}, LogicalType::VARCHAR, ICUCollateFunction, ICUCollateBind); //! 
collation tag is added into the Function extra info result.extra_info = tag; - result.serialize = IcuBindData::Serialize; - result.deserialize = IcuBindData::Deserialize; + result.SetSerializeCallback(IcuBindData::Serialize); + result.SetDeserializeCallback(IcuBindData::Deserialize); return result; } -unique_ptr GetTimeZoneInternal(string &tz_str, vector &candidates) { +unique_ptr GetKnownTimeZone(const string &tz_str) { icu::StringPiece tz_name_utf8(tz_str); const auto uid = icu::UnicodeString::fromUTF8(tz_name_utf8); duckdb::unique_ptr tz(icu::TimeZone::createTimeZone(uid)); @@ -217,6 +212,74 @@ unique_ptr GetTimeZoneInternal(string &tz_str, vector &ca return tz; } + return nullptr; +} + +static string NormalizeTimeZone(const string &tz_str) { + if (GetKnownTimeZone(tz_str)) { + return tz_str; + } + + // Map UTC±NN00 to Etc/UTC±N + do { + if (tz_str.size() <= 4) { + break; + } + if (tz_str.compare(0, 3, "UTC")) { + break; + } + + idx_t pos = 3; + const auto utc = tz_str[pos++]; + // Invert the sign (UTC and Etc use opposite sign conventions) + // https://en.wikipedia.org/wiki/Tz_database#Area + auto sign = utc; + if (utc == '+') { + sign = '-'; + ; + } else if (utc == '-') { + sign = '+'; + } else { + break; + } + + string mapped = "Etc/GMT"; + mapped += sign; + const auto base_len = mapped.size(); + for (; pos < tz_str.size(); ++pos) { + const auto digit = tz_str[pos]; + // We could get fancy here and count colons and their locations, but I doubt anyone cares. + if (digit == '0' || digit == ':') { + continue; + } + if (!StringUtil::CharacterIsDigit(digit)) { + break; + } + mapped += digit; + } + if (pos < tz_str.size()) { + break; + } + // If we didn't add anything, then make it +0 + if (mapped.size() == base_len) { + mapped.back() = '+'; + mapped += '0'; + } + // Final sanity check + if (GetKnownTimeZone(mapped)) { + return mapped; + } + } while (false); + + return tz_str; +} + +unique_ptr GetTimeZoneInternal(string &tz_str, vector &candidates) { + auto tz = GetKnownTimeZone(tz_str); + if (tz) { + return tz; + } + // Try to be friendlier // Go through all the zone names and look for a case insensitive match // If we don't find one, make a suggestion @@ -269,6 +332,7 @@ unique_ptr ICUHelpers::GetTimeZone(string &tz_str, string *error_ static void SetICUTimeZone(ClientContext &context, SetScope scope, Value ¶meter) { auto tz_str = StringValue::Get(parameter); + tz_str = NormalizeTimeZone(tz_str); ICUHelpers::GetTimeZone(tz_str); parameter = Value(tz_str); } @@ -362,18 +426,18 @@ static void SetICUCalendar(ClientContext &context, SetScope scope, Value ¶me } static void LoadInternal(ExtensionLoader &loader) { - // iterate over all the collations int32_t count; auto locales = icu::Collator::getAvailableLocales(count); for (int32_t i = 0; i < count; i++) { string collation; - if (string(locales[i].getCountry()).empty()) { + const auto &locale = locales[i]; // NOLINT + if (string(locale.getCountry()).empty()) { // language only - collation = locales[i].getLanguage(); + collation = locale.getLanguage(); } else { // language + country - collation = locales[i].getLanguage() + string("_") + locales[i].getCountry(); + collation = locale.getLanguage() + string("_") + locale.getCountry(); } collation = StringUtil::Lower(collation); @@ -405,6 +469,11 @@ static void LoadInternal(ExtensionLoader &loader) { icu::UnicodeString tz_id; std::string tz_string; tz->getID(tz_id).toUTF8String(tz_string); + // If the environment TZ is invalid, look for some alternatives + tz_string = 
NormalizeTimeZone(tz_string); + if (!GetKnownTimeZone(tz_string)) { + tz_string = "UTC"; + } config.AddExtensionOption("TimeZone", "The current time zone", LogicalType::VARCHAR, Value(tz_string), SetICUTimeZone); diff --git a/src/duckdb/extension/icu/third_party/icu/common/putil.cpp b/src/duckdb/extension/icu/third_party/icu/common/putil.cpp index c79811499..0c3fd9376 100644 --- a/src/duckdb/extension/icu/third_party/icu/common/putil.cpp +++ b/src/duckdb/extension/icu/third_party/icu/common/putil.cpp @@ -1090,9 +1090,15 @@ uprv_tzname(int n) if (tzid[0] == ':') { tzid++; } - /* This might be a good Olson ID. */ - skipZoneIDPrefix(&tzid); - return tzid; +#if defined(TZDEFAULT) + if (uprv_strcmp(tzid, TZDEFAULT) != 0) { +#endif + /* This might be a good Olson ID. */ + skipZoneIDPrefix(&tzid); + return tzid; +#if defined(TZDEFAULT) + } +#endif } /* else U_TZNAME will give a better result. */ #endif diff --git a/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h b/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h index e69de29bb..c1f295577 100644 --- a/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h +++ b/src/duckdb/extension/icu/third_party/icu/common/unicode/ucnv.h @@ -0,0 +1,11 @@ +/** + * Converter option for EBCDIC SBCS or mixed-SBCS/DBCS (stateful) codepages. + * Swaps Unicode mappings for EBCDIC LF and NL codes, as used on + * S/390 (z/OS) Unix System Services (Open Edition). + * For example, ucnv_open("ibm-1047,swaplfnl", &errorCode); + * See convrtrs.txt. + * + * @see ucnv_open + * @stable ICU 2.4 + */ +#define UCNV_SWAP_LFNL_OPTION_STRING ",swaplfnl" diff --git a/src/duckdb/extension/json/include/json_common.hpp b/src/duckdb/extension/json/include/json_common.hpp index f6dd78f05..81bbd6868 100644 --- a/src/duckdb/extension/json/include/json_common.hpp +++ b/src/duckdb/extension/json/include/json_common.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "yyjson.hpp" +#include "duckdb/common/types/blob.hpp" using namespace duckdb_yyjson; // NOLINT @@ -228,11 +229,8 @@ struct JSONCommon { static string FormatParseError(const char *data, idx_t length, yyjson_read_err &error, const string &extra = "") { D_ASSERT(error.code != YYJSON_READ_SUCCESS); - // Go to blob so we can have a better error message for weird strings - auto blob = Value::BLOB(string(data, length)); // Truncate, so we don't print megabytes worth of JSON - string input = blob.ToString(); - input = input.length() > 50 ? string(input.c_str(), 47) + "..." : input; + auto input = length > 50 ? string(data, 47) + "..." : string(data, length); // Have to replace \r, otherwise output is unreadable input = StringUtil::Replace(input, "\r", "\\r"); return StringUtil::Format("Malformed JSON at byte %lld of input: %s. 
%s Input: \"%s\"", error.pos, error.msg, diff --git a/src/duckdb/extension/json/include/json_reader.hpp b/src/duckdb/extension/json/include/json_reader.hpp index de75af996..b78da3e31 100644 --- a/src/duckdb/extension/json/include/json_reader.hpp +++ b/src/duckdb/extension/json/include/json_reader.hpp @@ -210,8 +210,8 @@ class JSONReader : public BaseFileReader { void PrepareReader(ClientContext &context, GlobalTableFunctionState &) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; diff --git a/src/duckdb/extension/json/include/json_serializer.hpp b/src/duckdb/extension/json/include/json_serializer.hpp index aa17f3ffd..e856bff79 100644 --- a/src/duckdb/extension/json/include/json_serializer.hpp +++ b/src/duckdb/extension/json/include/json_serializer.hpp @@ -39,6 +39,18 @@ struct JsonSerializer : Serializer { return serializer.GetRootObject(); } + template + static string SerializeToString(T &value) { + auto doc = yyjson_mut_doc_new(nullptr); + JsonSerializer serializer(doc, false, false, false); + value.Serialize(serializer); + auto result_obj = serializer.GetRootObject(); + idx_t len = 0; + auto data = yyjson_mut_val_write_opts(result_obj, JSONCommon::WRITE_PRETTY_FLAG, nullptr, + reinterpret_cast(&len), nullptr); + return string(data, len); + } + yyjson_mut_val *GetRootObject() { D_ASSERT(stack.size() == 1); // or we forgot to pop somewhere return stack.front(); diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index 2d09828c3..2d0ef11f5 100644 --- a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -394,7 +394,11 @@ void JSONFunctions::RegisterSimpleCastFunctions(ExtensionLoader &loader) { loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalTypeId::VARCHAR, CastJSONListToVarchar, json_list_to_varchar_cost); - // VARCHAR to JSON[] (also needs a special case otherwise get a VARCHAR -> VARCHAR[] cast first) + // JSON[] to JSON is allowed implicitly + loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalType::JSON(), CastJSONListToVarchar, + 100); + + // VARCHAR to JSON[] (also needs a special case otherwise we get a VARCHAR -> VARCHAR[] cast first) const auto varchar_to_json_list_cost = CastFunctionSet::ImplicitCastCost(db, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::JSON())) - 1; BoundCastInfo varchar_to_json_list_info(CastVarcharToJSONList, nullptr, JSONFunctionLocalState::InitCastLocalState); diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 4cd00249c..d1c8a8afb 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -111,11 +111,11 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct auto &type = arguments[i]->return_type; if (arguments[i]->HasParameter()) { throw ParameterNotResolvedException(); - } else if (type == 
LogicalTypeId::SQLNULL) { - // This is needed for macro's - bound_function.arguments.push_back(type); } else if (object && i % 2 == 0) { - // Key, must be varchar + if (type != LogicalType::VARCHAR) { + throw BinderException("json_object() keys must be VARCHAR, add an explicit cast to argument \"%s\"", + arguments[i]->GetName()); + } bound_function.arguments.push_back(LogicalType::VARCHAR); } else { // Value, cast to types that we can put in JSON @@ -128,7 +128,7 @@ static unique_ptr JSONCreateBindParams(ScalarFunction &bound_funct static unique_ptr JSONObjectBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() % 2 != 0) { - throw InvalidInputException("json_object() requires an even number of arguments"); + throw BinderException("json_object() requires an even number of arguments"); } return JSONCreateBindParams(bound_function, arguments, true); } @@ -141,7 +141,7 @@ static unique_ptr JSONArrayBind(ClientContext &context, ScalarFunc static unique_ptr ToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("to_json() takes exactly one argument"); + throw BinderException("to_json() takes exactly one argument"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -149,14 +149,14 @@ static unique_ptr ToJSONBind(ClientContext &context, ScalarFunctio static unique_ptr ArrayToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("array_to_json() takes exactly one argument"); + throw BinderException("array_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arg_id != LogicalTypeId::LIST && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("array_to_json() argument type must be LIST"); + throw BinderException("array_to_json() argument type must be LIST"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -164,14 +164,14 @@ static unique_ptr ArrayToJSONBind(ClientContext &context, ScalarFu static unique_ptr RowToJSONBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 1) { - throw InvalidInputException("row_to_json() takes exactly one argument"); + throw BinderException("row_to_json() takes exactly one argument"); } auto arg_id = arguments[0]->return_type.id(); if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } if (arguments[0]->return_type.id() != LogicalTypeId::STRUCT && arg_id != LogicalTypeId::SQLNULL) { - throw InvalidInputException("row_to_json() argument type must be STRUCT"); + throw BinderException("row_to_json() argument type must be STRUCT"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -473,7 +473,6 @@ static void CreateValuesList(const StructNames &names, yyjson_mut_doc *doc, yyjs static void CreateValuesArray(const StructNames &names, yyjson_mut_doc *doc, yyjson_mut_val *vals[], Vector &value_v, idx_t count) { - value_v.Flatten(count); // Initialize array for the nested values @@ -616,6 +615,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::VALIDITY: case LogicalTypeId::TABLE: case LogicalTypeId::LAMBDA: + case LogicalTypeId::GEOMETRY: // TODO! 
Add support for GEOMETRY throw InternalException("Unsupported type arrived at JSON create function"); } } @@ -728,7 +728,7 @@ ScalarFunctionSet JSONFunctions::GetObjectFunction() { ScalarFunction fun("json_object", {}, LogicalType::JSON(), ObjectFunction, JSONObjectBind, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } @@ -736,7 +736,7 @@ ScalarFunctionSet JSONFunctions::GetArrayFunction() { ScalarFunction fun("json_array", {}, LogicalType::JSON(), ArrayFunction, JSONArrayBind, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } diff --git a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp index 225228924..de1caadc2 100644 --- a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp +++ b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp @@ -84,7 +84,7 @@ ScalarFunctionSet JSONFunctions::GetMergePatchFunction() { ScalarFunction fun("json_merge_patch", {LogicalType::JSON(), LogicalType::JSON()}, LogicalType::JSON(), MergePatchFunction, nullptr, nullptr, nullptr, JSONFunctionLocalState::Init); fun.varargs = LogicalType::JSON(); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunctionSet(fun); } diff --git a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp index 787e393b9..404164181 100644 --- a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp +++ b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp @@ -284,7 +284,6 @@ static void InitializeLocalState(JSONTableInOutLocalState &lstate, DataChunk &in template static bool JSONTableInOutHandleValue(JSONTableInOutLocalState &lstate, JSONTableInOutResult &result, idx_t &child_index, size_t &idx, yyjson_val *child_key, yyjson_val *child_val) { - if (idx < child_index) { return false; // Continue: Get back to where we left off } diff --git a/src/duckdb/extension/json/json_multi_file_info.cpp b/src/duckdb/extension/json/json_multi_file_info.cpp index 7771af489..1f131e6af 100644 --- a/src/duckdb/extension/json/json_multi_file_info.cpp +++ b/src/duckdb/extension/json/json_multi_file_info.cpp @@ -1,6 +1,7 @@ #include "json_multi_file_info.hpp" #include "json_scan.hpp" #include "duckdb/common/types/value.hpp" +#include "duckdb/parallel/async_result.hpp" namespace duckdb { @@ -530,8 +531,17 @@ void ReadJSONObjectsFunction(ClientContext &context, JSONReader &json_reader, JS output.SetCardinality(count); } -void JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &output) { +AsyncResult JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &output) { +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + { + vector> tasks = AsyncResult::GenerateTestTasks(); + if (!tasks.empty()) { + return AsyncResult(std::move(tasks)); + } + } +#endif + auto &gstate = global_state.Cast().state; auto &lstate = local_state.Cast().state; auto &json_data = 
gstate.bind_data.bind_data->Cast(); @@ -545,6 +555,7 @@ void JSONReader::Scan(ClientContext &context, GlobalTableFunctionState &global_s default: throw InternalException("Unsupported scan type for JSONMultiFileInfo::Scan"); } + return AsyncResult(output.size() ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED); } void JSONReader::FinishFile(ClientContext &context, GlobalTableFunctionState &global_state) { diff --git a/src/duckdb/extension/json/json_reader.cpp b/src/duckdb/extension/json/json_reader.cpp index b52026a4e..ad61da7d2 100644 --- a/src/duckdb/extension/json/json_reader.cpp +++ b/src/duckdb/extension/json/json_reader.cpp @@ -184,8 +184,7 @@ void JSONReader::OpenJSONFile() { if (!IsOpen()) { auto &fs = FileSystem::GetFileSystem(context); auto regular_file_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ | options.compression); - file_handle = make_uniq(QueryContext(context), std::move(regular_file_handle), - BufferAllocator::Get(context)); + file_handle = make_uniq(context, std::move(regular_file_handle), BufferAllocator::Get(context)); } Reset(); } diff --git a/src/duckdb/extension/loader/dummy_static_extension_loader.cpp b/src/duckdb/extension/loader/dummy_static_extension_loader.cpp new file mode 100644 index 000000000..9275653ea --- /dev/null +++ b/src/duckdb/extension/loader/dummy_static_extension_loader.cpp @@ -0,0 +1,15 @@ +#include "duckdb/main/extension_helper.hpp" + +// This is a dummy loader to produce a workable duckdb library without linking any extensions. +// Link this to libduckdb_static.a to get a working system. + +namespace duckdb { +void ExtensionHelper::LoadAllExtensions(DuckDB &db) { + // nop +} + +vector ExtensionHelper::LoadedExtensionTestPaths() { + return {}; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index c13a71b6f..5ba8b4dd4 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -28,6 +28,8 @@ #include "duckdb/common/helper.hpp" #include "duckdb/common/types/bit.hpp" +#include "parquet_crypto.hpp" + namespace duckdb { using duckdb_parquet::CompressionCodec; @@ -109,7 +111,7 @@ const uint8_t ParquetDecodeUtils::BITPACK_DLEN = 8; ColumnReader::ColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema_p) : column_schema(schema_p), reader(reader), page_rows_available(0), dictionary_decoder(*this), delta_binary_packed_decoder(*this), rle_decoder(*this), delta_length_byte_array_decoder(*this), - delta_byte_array_decoder(*this), byte_stream_split_decoder(*this) { + delta_byte_array_decoder(*this), byte_stream_split_decoder(*this), aad_crypto_metadata(reader.allocator) { } ColumnReader::~ColumnReader() { @@ -232,6 +234,36 @@ bool ColumnReader::PageIsFilteredOut(PageHeader &page_hdr) { return true; } +void ColumnReader::ReadEncrypted(duckdb_apache::thrift::TBase &object) { + aad_crypto_metadata.module = ParquetCrypto::GetModuleHeader(*chunk, aad_crypto_metadata.page_ordinal); + aad_crypto_metadata.page_ordinal = + ParquetCrypto::GetFinalPageOrdinal(*chunk, aad_crypto_metadata.module, aad_crypto_metadata.page_ordinal); + reader.ReadEncrypted(object, *protocol, aad_crypto_metadata); +} + +void ColumnReader::ReadDataEncrypted(const data_ptr_t buffer, const uint32_t buffer_size, PageType::type page_type) { + aad_crypto_metadata.module = ParquetCrypto::GetModule(*chunk, page_type, aad_crypto_metadata.page_ordinal); + aad_crypto_metadata.page_ordinal = + 
ParquetCrypto::GetFinalPageOrdinal(*chunk, aad_crypto_metadata.module, aad_crypto_metadata.page_ordinal); + reader.ReadDataEncrypted(*protocol, buffer, buffer_size, aad_crypto_metadata); +} + +void ColumnReader::Read(PageHeader &page_hdr) { + if (reader.parquet_options.encryption_config) { + ReadEncrypted(page_hdr); + } else { + reader.Read(page_hdr, *protocol); + } +} + +void ColumnReader::ReadData(const data_ptr_t buffer, const uint32_t buffer_size, PageType::type page_type) { + if (reader.parquet_options.encryption_config) { + ReadDataEncrypted(buffer, buffer_size, page_type); + } else { + reader.ReadData(*protocol, buffer, buffer_size); + } +} + void ColumnReader::PrepareRead(optional_ptr filter, optional_ptr filter_state) { encoding = ColumnEncoding::INVALID; defined_decoder.reset(); @@ -239,16 +271,17 @@ void ColumnReader::PrepareRead(optional_ptr filter, optional_ block.reset(); PageHeader page_hdr; auto &trans = reinterpret_cast(*protocol->getTransport()); + if (trans.HasPrefetch()) { // Already has some data prefetched, let's not mess with it - reader.Read(page_hdr, *protocol); + Read(page_hdr); } else { // No prefetch yet, prefetch the full header in one go (so thrift won't read byte-by-byte from storage) // 256 bytes should cover almost all headers (unless it's a V2 header with really LONG string statistics) static constexpr idx_t ASSUMED_HEADER_SIZE = 256; const auto prefetch_size = MinValue(trans.GetSize() - trans.GetLocation(), ASSUMED_HEADER_SIZE); trans.Prefetch(trans.GetLocation(), prefetch_size); - reader.Read(page_hdr, *protocol); + Read(page_hdr); trans.ClearPrefetch(); } // some basic sanity check @@ -304,7 +337,7 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { uncompressed = true; } if (uncompressed) { - reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size); + ReadData(block->ptr, page_hdr.compressed_page_size, page_hdr.type); return; } @@ -317,14 +350,16 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { "repetition_levels_byte_length + definition_levels_byte_length", Reader().GetFileName()); } - reader.ReadData(*protocol, block->ptr, uncompressed_bytes); + + ReadData(block->ptr, uncompressed_bytes, page_hdr.type); auto compressed_bytes = page_hdr.compressed_page_size - uncompressed_bytes; if (compressed_bytes > 0) { ResizeableBuffer compressed_buffer; compressed_buffer.resize(GetAllocator(), compressed_bytes); - reader.ReadData(*protocol, compressed_buffer.ptr, compressed_bytes); + + ReadData(compressed_buffer.ptr, compressed_bytes, page_hdr.type); DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, compressed_bytes, block->ptr + uncompressed_bytes, page_hdr.uncompressed_page_size - uncompressed_bytes); @@ -341,19 +376,32 @@ void ColumnReader::AllocateBlock(idx_t size) { void ColumnReader::PreparePage(PageHeader &page_hdr) { AllocateBlock(page_hdr.uncompressed_page_size + 1); + uint32_t compressed_page_size = page_hdr.compressed_page_size; + + if (chunk->__isset.crypto_metadata) { + auto const file_aad = reader.GetUniqueFileIdentifier(reader.metadata->crypto_metadata->encryption_algorithm); + if (!file_aad.empty()) { + // If there is a file aad (identifier), this means that the Encrypted file is written by Arrow + // Arrow adds the bytes for encryption (len + nonce + tag) + // to the compressed page size + compressed_page_size -= + (ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + ParquetCrypto::TAG_BYTES); + } + } + if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { - if 
(page_hdr.compressed_page_size != page_hdr.uncompressed_page_size) { + if (compressed_page_size != page_hdr.uncompressed_page_size) { throw std::runtime_error("Page size mismatch"); } - reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size); + ReadData(block->ptr, compressed_page_size, page_hdr.type); return; } ResizeableBuffer compressed_buffer; - compressed_buffer.resize(GetAllocator(), page_hdr.compressed_page_size + 1); - reader.ReadData(*protocol, compressed_buffer.ptr, page_hdr.compressed_page_size); + compressed_buffer.resize(GetAllocator(), compressed_page_size + 1); + ReadData(compressed_buffer.ptr, compressed_page_size, page_hdr.type); - DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, page_hdr.compressed_page_size, block->ptr, + DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, compressed_page_size, block->ptr, page_hdr.uncompressed_page_size); } @@ -523,8 +571,11 @@ void ColumnReader::BeginRead(data_ptr_t define_out, data_ptr_t repeat_out) { idx_t ColumnReader::ReadPageHeaders(idx_t max_read, optional_ptr filter, optional_ptr filter_state) { + int8_t page_ordinal = 0; while (page_rows_available == 0) { + aad_crypto_metadata.page_ordinal = page_ordinal; PrepareRead(filter, filter_state); + page_ordinal++; } return MinValue(MinValue(max_read, page_rows_available), STANDARD_VECTOR_SIZE); } @@ -895,7 +946,6 @@ unique_ptr ColumnReader::CreateReader(ParquetReader &reader, const default: throw NotImplementedException("Unrecognized Parquet type for Decimal"); } - break; case LogicalTypeId::UUID: return make_uniq(reader, schema); case LogicalTypeId::INTERVAL: diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 7cdd51bc5..cf0652ede 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -1,7 +1,7 @@ #include "column_writer.hpp" #include "duckdb.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_rle_bp_decoder.hpp" #include "parquet_bss_encoder.hpp" #include "parquet_statistics.hpp" @@ -13,6 +13,7 @@ #include "writer/list_column_writer.hpp" #include "writer/primitive_column_writer.hpp" #include "writer/struct_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include "writer/templated_column_writer.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/operator/comparison_operators.hpp" @@ -96,7 +97,7 @@ bool ColumnWriterStatistics::HasGeoStats() { return false; } -optional_ptr ColumnWriterStatistics::GetGeoStats() { +optional_ptr ColumnWriterStatistics::GetGeoStats() { return nullptr; } @@ -107,10 +108,9 @@ void ColumnWriterStatistics::WriteGeoStats(duckdb_parquet::GeospatialStatistics //===--------------------------------------------------------------------===// // ColumnWriter //===--------------------------------------------------------------------===// -ColumnWriter::ColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : writer(writer), column_schema(column_schema), schema_path(std::move(schema_path_p)), - can_have_nulls(can_have_nulls) { +ColumnWriter::ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema_p, vector schema_path_p) + : writer(writer), column_schema(std::move(column_schema_p)), schema_path(std::move(schema_path_p)) { + can_have_nulls = column_schema.repetition_type == duckdb_parquet::FieldRepetitionType::OPTIONAL; } ColumnWriter::~ColumnWriter() { } @@ 
-181,8 +181,7 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si } } -void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count, - idx_t max_repeat) const { +void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count) const { if (!parent) { // no repeat levels without a parent node return; @@ -244,18 +243,22 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat // Create Column Writer //===--------------------------------------------------------------------===// -ParquetColumnSchema ColumnWriter::FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) { - auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; +unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, + vector path_in_schema, const LogicalType &type, + const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat, idx_t max_define, bool can_have_nulls) { + path_in_schema.push_back(name); + if (!can_have_nulls) { max_define--; } - idx_t schema_idx = schemas.size(); + auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; optional_ptr field_id; optional_ptr child_field_ids; + optional_ptr shredding_type; if (field_ids) { auto field_id_it = field_ids->ids->find(name); if (field_id_it != field_ids->ids->end()) { @@ -263,268 +266,226 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorchild_field_ids; } } + if (shredding_types) { + shredding_type = shredding_types->GetChild(name); + } + + if (type.id() == LogicalTypeId::VARIANT) { + const bool is_shredded = shredding_type != nullptr; + + //! Build the child types for the Parquet VARIANT + child_list_t child_types; + child_types.emplace_back("metadata", LogicalType::BLOB); + child_types.emplace_back("value", LogicalType::BLOB); + if (is_shredded) { + auto &typed_value_type = shredding_type->type; + if (typed_value_type.id() != LogicalTypeId::ANY) { + child_types.emplace_back("typed_value", + VariantColumnWriter::TransformTypedValueRecursive(typed_value_type)); + } + } + + //! Construct the column schema + auto variant_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + vector> child_writers; + child_writers.reserve(child_types.size()); + + //! Then construct the child writers for the Parquet VARIANT + for (auto &entry : child_types) { + auto &child_name = entry.first; + auto &child_type = entry.second; + bool is_optional; + if (child_name == "metadata") { + is_optional = false; + } else if (child_name == "value") { + if (is_shredded) { + //! 
When shredding the variant, the 'value' becomes optional + is_optional = true; + } else { + is_optional = false; + } + } else { + D_ASSERT(child_name == "typed_value"); + is_optional = true; + } + + child_writers.push_back(CreateWriterRecursive(context, writer, path_in_schema, child_type, child_name, + allow_geometry, child_field_ids, shredding_type, max_repeat, + max_define + 1, is_optional)); + } + return make_uniq(writer, std::move(variant_column), path_in_schema, + std::move(child_writers)); + } if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { - auto &child_types = StructType::GetChildTypes(type); - // set up the schema element for this struct - duckdb_parquet::SchemaElement schema_element; - schema_element.repetition_type = null_type; - schema_element.num_children = UnsafeNumericCast(child_types.size()); - schema_element.__isset.num_children = true; - schema_element.__isset.type = false; - schema_element.__isset.repetition_type = true; - schema_element.name = name; + auto struct_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; + struct_column.field_id = field_id->field_id; } - schemas.push_back(std::move(schema_element)); - ParquetColumnSchema struct_column(name, type, max_define, max_repeat, schema_idx, 0); // construct the child schemas recursively - struct_column.children.reserve(child_types.size()); - for (auto &child_type : child_types) { - struct_column.children.emplace_back(FillParquetSchema(schemas, child_type.second, child_type.first, - child_field_ids, max_repeat, max_define + 1)); + auto &child_types = StructType::GetChildTypes(type); + vector> child_writers; + child_writers.reserve(child_types.size()); + for (auto &entry : child_types) { + auto &child_type = entry.second; + auto &child_name = entry.first; + child_writers.push_back(CreateWriterRecursive(context, writer, path_in_schema, child_type, child_name, + allow_geometry, child_field_ids, shredding_type, max_repeat, + max_define + 1, true)); } - return struct_column; + return make_uniq(writer, std::move(struct_column), std::move(path_in_schema), + std::move(child_writers)); } + if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::ARRAY) { auto is_list = type.id() == LogicalTypeId::LIST; auto &child_type = is_list ? 
ListType::GetChildType(type) : ArrayType::GetChildType(type); - // set up the two schema elements for the list - // for some reason we only set the converted type in the OPTIONAL element - // first an OPTIONAL element - duckdb_parquet::SchemaElement optional_element; - optional_element.repetition_type = null_type; - optional_element.num_children = 1; - optional_element.converted_type = ConvertedType::LIST; - optional_element.__isset.num_children = true; - optional_element.__isset.type = false; - optional_element.__isset.repetition_type = true; - optional_element.__isset.converted_type = true; - optional_element.name = name; - if (field_id && field_id->set) { - optional_element.__isset.field_id = true; - optional_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(optional_element)); - - // then a REPEATED element - duckdb_parquet::SchemaElement repeated_element; - repeated_element.repetition_type = FieldRepetitionType::REPEATED; - repeated_element.num_children = 1; - repeated_element.__isset.num_children = true; - repeated_element.__isset.type = false; - repeated_element.__isset.repetition_type = true; - repeated_element.name = "list"; - schemas.push_back(std::move(repeated_element)); - - ParquetColumnSchema list_column(name, type, max_define, max_repeat, schema_idx, 0); - list_column.children.push_back( - FillParquetSchema(schemas, child_type, "element", child_field_ids, max_repeat + 1, max_define + 2)); - return list_column; - } - if (type.id() == LogicalTypeId::MAP) { - // map type - // maps are stored as follows: - // group (MAP) { - // repeated group key_value { - // required key; - // value; - // } - // } - // top map element - duckdb_parquet::SchemaElement top_element; - top_element.repetition_type = null_type; - top_element.num_children = 1; - top_element.converted_type = ConvertedType::MAP; - top_element.__isset.repetition_type = true; - top_element.__isset.num_children = true; - top_element.__isset.converted_type = true; - top_element.__isset.type = false; - top_element.name = name; - if (field_id && field_id->set) { - top_element.__isset.field_id = true; - top_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(top_element)); - - // key_value element - duckdb_parquet::SchemaElement kv_element; - kv_element.repetition_type = FieldRepetitionType::REPEATED; - kv_element.num_children = 2; - kv_element.__isset.repetition_type = true; - kv_element.__isset.num_children = true; - kv_element.__isset.type = false; - kv_element.name = "key_value"; - schemas.push_back(std::move(kv_element)); - - // construct the child types recursively - vector kv_types {MapType::KeyType(type), MapType::ValueType(type)}; - vector kv_names {"key", "value"}; - ParquetColumnSchema map_column(name, type, max_define, max_repeat, schema_idx, 0); - map_column.children.reserve(2); - for (idx_t i = 0; i < 2; i++) { - // key needs to be marked as REQUIRED - bool is_key = i == 0; - auto child_schema = FillParquetSchema(schemas, kv_types[i], kv_names[i], child_field_ids, max_repeat + 1, - max_define + 2, !is_key); + path_in_schema.push_back("list"); + auto child_writer = + CreateWriterRecursive(context, writer, path_in_schema, child_type, "element", allow_geometry, + child_field_ids, shredding_type, max_repeat + 1, max_define + 2, true); - map_column.children.push_back(std::move(child_schema)); + auto list_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + list_column.field_id = 
field_id->field_id; } - return map_column; - } - duckdb_parquet::SchemaElement schema_element; - schema_element.type = ParquetWriter::DuckDBTypeToParquetType(type); - schema_element.repetition_type = null_type; - schema_element.__isset.num_children = false; - schema_element.__isset.type = true; - schema_element.__isset.repetition_type = true; - schema_element.name = name; - if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; - } - ParquetWriter::SetSchemaProperties(type, schema_element); - schemas.push_back(std::move(schema_element)); - return ParquetColumnSchema(name, type, max_define, max_repeat, schema_idx, 0); -} - -unique_ptr -ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - const vector &parquet_schemas, - const ParquetColumnSchema &schema, vector path_in_schema) { - auto &type = schema.type; - auto can_have_nulls = parquet_schemas[schema.schema_index].repetition_type == FieldRepetitionType::OPTIONAL; - path_in_schema.push_back(schema.name); - if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { - // construct the child writers recursively - vector> child_writers; - child_writers.reserve(schema.children.size()); - for (auto &child_column : schema.children) { - child_writers.push_back( - CreateWriterRecursive(context, writer, parquet_schemas, child_column, path_in_schema)); - } - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writers), - can_have_nulls); - } - if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::ARRAY) { - auto is_list = type.id() == LogicalTypeId::LIST; - path_in_schema.push_back("list"); - auto child_writer = CreateWriterRecursive(context, writer, parquet_schemas, schema.children[0], path_in_schema); if (is_list) { - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writer), - can_have_nulls); + return make_uniq(writer, std::move(list_column), std::move(path_in_schema), + std::move(child_writer)); } else { - return make_uniq(writer, schema, std::move(path_in_schema), std::move(child_writer), - can_have_nulls); + return make_uniq(writer, std::move(list_column), std::move(path_in_schema), + std::move(child_writer)); } } + if (type.id() == LogicalTypeId::MAP) { path_in_schema.push_back("key_value"); + // construct the child types recursively + child_list_t key_value; + key_value.reserve(2); + key_value.emplace_back("key", MapType::KeyType(type)); + key_value.emplace_back("value", MapType::ValueType(type)); + auto key_value_type = LogicalType::STRUCT(key_value); + + auto map_column = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + map_column.field_id = field_id->field_id; + } + vector> child_writers; child_writers.reserve(2); for (idx_t i = 0; i < 2; i++) { // key needs to be marked as REQUIRED + bool is_key = i == 0; + auto &child_name = key_value[i].first; + auto &child_type = key_value[i].second; auto child_writer = - CreateWriterRecursive(context, writer, parquet_schemas, schema.children[i], path_in_schema); + CreateWriterRecursive(context, writer, path_in_schema, child_type, child_name, allow_geometry, + child_field_ids, shredding_type, max_repeat + 1, max_define + 2, !is_key); + child_writers.push_back(std::move(child_writer)); } - auto struct_writer = - make_uniq(writer, schema, path_in_schema, std::move(child_writers), can_have_nulls); - return make_uniq(writer, schema, 
path_in_schema, std::move(struct_writer), can_have_nulls); + + auto key_value_schema = + ParquetColumnSchema::FromLogicalType("key_value", key_value_type, max_define + 1, max_repeat + 1, 0, + FieldRepetitionType::REPEATED, allow_geometry); + auto struct_writer = make_uniq(writer, std::move(key_value_schema), path_in_schema, + std::move(child_writers)); + return make_uniq(writer, std::move(map_column), path_in_schema, std::move(struct_writer)); } - if (type.id() == LogicalTypeId::BLOB && type.GetAlias() == "WKB_BLOB") { - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + auto schema = + ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); + if (field_id && field_id->set) { + schema.field_id = field_id->field_id; } switch (type.id()) { case LogicalTypeId::BOOLEAN: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::TINYINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::SMALLINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::INTEGER: case LogicalTypeId::DATE: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::BIGINT: case LogicalTypeId::TIME: case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: case LogicalTypeId::TIMESTAMP_MS: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::TIME_TZ: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::HUGEINT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UHUGEINT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::TIMESTAMP_NS: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::TIMESTAMP_SEC: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UTINYINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::USMALLINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::UINTEGER: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UBIGINT: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return 
make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::FLOAT: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::DOUBLE: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::DECIMAL: switch (type.InternalType()) { case PhysicalType::INT16: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case PhysicalType::INT32: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case PhysicalType::INT64: - return make_uniq>(writer, schema, std::move(path_in_schema), - can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); default: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); } case LogicalTypeId::BLOB: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); + case LogicalTypeId::GEOMETRY: + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::VARCHAR: - return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq>(writer, std::move(schema), + std::move(path_in_schema)); case LogicalTypeId::UUID: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::INTERVAL: return make_uniq>( - writer, schema, std::move(path_in_schema), can_have_nulls); + writer, std::move(schema), std::move(path_in_schema)); case LogicalTypeId::ENUM: - return make_uniq(writer, schema, std::move(path_in_schema), can_have_nulls); + return make_uniq(writer, std::move(schema), std::move(path_in_schema)); default: throw InternalException("Unsupported type \"%s\" in Parquet writer", type.ToString()); } @@ -533,10 +494,10 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write template <> struct NumericLimits { static constexpr float Minimum() { - return std::numeric_limits::lowest(); + return NumericLimits::Minimum(); }; static constexpr float Maximum() { - return std::numeric_limits::max(); + return NumericLimits::Maximum(); }; static constexpr bool IsSigned() { return std::is_signed::value; @@ -549,10 +510,10 @@ struct NumericLimits { template <> struct NumericLimits { static constexpr double Minimum() { - return std::numeric_limits::lowest(); + return NumericLimits::Minimum(); }; static constexpr double Maximum() { - return std::numeric_limits::max(); + return NumericLimits::Maximum(); }; static constexpr bool IsSigned() { return std::is_signed::value; diff --git a/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp b/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp index 9a0c1eac5..a2fd7abd9 100644 --- a/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp +++ b/src/duckdb/extension/parquet/decoder/delta_length_byte_array_decoder.cpp @@ -34,13 +34,21 @@ void DeltaLengthByteArrayDecoder::InitializePage() { void 
DeltaLengthByteArrayDecoder::Read(shared_ptr &block_ref, uint8_t *defines, idx_t read_count, Vector &result, idx_t result_offset) { if (defines) { - ReadInternal(block_ref, defines, read_count, result, result_offset); + if (reader.Type().IsJSONType()) { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } else { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } } else { - ReadInternal(block_ref, defines, read_count, result, result_offset); + if (reader.Type().IsJSONType()) { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } else { + ReadInternal(block_ref, defines, read_count, result, result_offset); + } } } -template +template void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &block_ref, uint8_t *const defines, const idx_t read_count, Vector &result, const idx_t result_offset) { auto &block = *block_ref; @@ -58,6 +66,8 @@ void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &blo } } + const auto &string_column_reader = reader.Cast(); + const auto start_ptr = block.ptr; for (idx_t row_idx = 0; row_idx < read_count; row_idx++) { const auto result_idx = result_offset + row_idx; @@ -75,11 +85,15 @@ void DeltaLengthByteArrayDecoder::ReadInternal(shared_ptr &blo } const auto &str_len = length_data[length_idx++]; result_data[result_idx] = string_t(char_ptr_cast(block.ptr), str_len); + if (VALIDATE_INDIVIDUAL_STRINGS) { + string_column_reader.VerifyString(char_ptr_cast(block.ptr), str_len); + } block.unsafe_inc(str_len); } - // Verify that the strings we read are valid UTF-8 - reader.Cast().VerifyString(char_ptr_cast(start_ptr), block.ptr - start_ptr); + if (!VALIDATE_INDIVIDUAL_STRINGS) { + string_column_reader.VerifyString(char_ptr_cast(start_ptr), NumericCast(block.ptr - start_ptr)); + } StringColumnReader::ReferenceBlock(result, block_ref); } diff --git a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp index dfce2343b..79dc43d9c 100644 --- a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp +++ b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp @@ -14,39 +14,36 @@ DictionaryDecoder::DictionaryDecoder(ColumnReader &reader) void DictionaryDecoder::InitializeDictionary(idx_t new_dictionary_size, optional_ptr filter, optional_ptr filter_state, bool has_defines) { - auto old_dict_size = dictionary_size; dictionary_size = new_dictionary_size; filter_result.reset(); filter_count = 0; can_have_nulls = has_defines; - // we use the first value in the dictionary to keep a NULL - if (!dictionary) { - dictionary = make_uniq(reader.Type(), dictionary_size + 1); - } else if (dictionary_size > old_dict_size) { - dictionary->Resize(old_dict_size, dictionary_size + 1); - } - dictionary_id = - reader.reader.GetFileName() + "_" + reader.Schema().name + "_" + std::to_string(reader.chunk_read_offset); + // we use the last entry as a NULL, dictionary vectors don't have a separate validity mask - auto &dict_validity = FlatVector::Validity(*dictionary); - dict_validity.Reset(dictionary_size + 1); + const auto duckdb_dictionary_size = dictionary_size + can_have_nulls; + dictionary = DictionaryVector::CreateReusableDictionary(reader.Type(), duckdb_dictionary_size); + auto &dict_validity = FlatVector::Validity(dictionary->data); + dict_validity.Reset(duckdb_dictionary_size); if (can_have_nulls) { dict_validity.SetInvalid(dictionary_size); } - reader.Plain(reader.block, nullptr, dictionary_size, 0, *dictionary); + // now read the non-NULL 
values from Parquet + reader.Plain(reader.block, nullptr, dictionary_size, 0, dictionary->data); + + // immediately filter the dictionary, if applicable if (filter && CanFilter(*filter, *filter_state)) { // no filter result yet - apply filter to the dictionary // initialize the filter result - setting everything to false - filter_result = make_unsafe_uniq_array(dictionary_size); + filter_result = make_unsafe_uniq_array(duckdb_dictionary_size); // apply the filter UnifiedVectorFormat vdata; - dictionary->ToUnifiedFormat(dictionary_size, vdata); + dictionary->data.ToUnifiedFormat(duckdb_dictionary_size, vdata); SelectionVector dict_sel; - filter_count = dictionary_size; - ColumnSegment::FilterSelection(dict_sel, *dictionary, vdata, *filter, *filter_state, dictionary_size, - filter_count); + filter_count = duckdb_dictionary_size; + ColumnSegment::FilterSelection(dict_sel, dictionary->data, vdata, *filter, *filter_state, + duckdb_dictionary_size, filter_count); // now set all matching tuples to true for (idx_t i = 0; i < filter_count; i++) { @@ -91,13 +88,14 @@ idx_t DictionaryDecoder::GetValidValues(uint8_t *defines, idx_t read_count, idx_ } idx_t DictionaryDecoder::Read(uint8_t *defines, idx_t read_count, Vector &result, idx_t result_offset) { - if (!dictionary || dictionary_size < 0) { + if (!dictionary) { throw std::runtime_error("Parquet file is likely corrupted, missing dictionary"); } idx_t valid_count = GetValidValues(defines, read_count, result_offset); if (valid_count == read_count) { // all values are valid - we can directly decompress the offsets into the selection vector - dict_decoder->GetBatch(data_ptr_cast(dictionary_selection_vector.data()), valid_count); + dict_decoder->GetBatch(data_ptr_cast(dictionary_selection_vector.data()), + NumericCast(valid_count)); // we do still need to verify the offsets though uint32_t max_index = 0; for (idx_t idx = 0; idx < valid_count; idx++) { @@ -109,30 +107,29 @@ idx_t DictionaryDecoder::Read(uint8_t *defines, idx_t read_count, Vector &result } else if (valid_count > 0) { // for the valid entries - decode the offsets offset_buffer.resize(reader.reader.allocator, sizeof(uint32_t) * valid_count); - dict_decoder->GetBatch(offset_buffer.ptr, valid_count); + dict_decoder->GetBatch(offset_buffer.ptr, NumericCast(valid_count)); ConvertDictToSelVec(reinterpret_cast(offset_buffer.ptr), valid_sel, valid_count); } #ifdef DEBUG dictionary_selection_vector.Verify(read_count, dictionary_size + can_have_nulls); #endif if (result_offset == 0) { - result.Dictionary(*dictionary, dictionary_size + can_have_nulls, dictionary_selection_vector, read_count); - DictionaryVector::SetDictionaryId(result, dictionary_id); + result.Dictionary(dictionary, dictionary_selection_vector); D_ASSERT(result.GetVectorType() == VectorType::DICTIONARY_VECTOR); } else { D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(*dictionary, result, dictionary_selection_vector, read_count, 0, result_offset); + VectorOperations::Copy(dictionary->data, result, dictionary_selection_vector, read_count, 0, result_offset); } return valid_count; } void DictionaryDecoder::Skip(uint8_t *defines, idx_t skip_count) { - if (!dictionary || dictionary_size < 0) { + if (!dictionary) { throw std::runtime_error("Parquet file is likely corrupted, missing dictionary"); } idx_t valid_count = reader.GetValidCount(defines, skip_count); // skip past the valid offsets - dict_decoder->Skip(valid_count); + dict_decoder->Skip(NumericCast(valid_count)); } bool 
DictionaryDecoder::DictionarySupportsFilter(const TableFilter &filter, TableFilterState &filter_state) { @@ -193,7 +190,7 @@ bool DictionaryDecoder::CanFilter(const TableFilter &filter, TableFilterState &f void DictionaryDecoder::Filter(uint8_t *defines, const idx_t read_count, Vector &result, SelectionVector &sel, idx_t &approved_tuple_count) { - if (!dictionary || dictionary_size < 0) { + if (!dictionary) { throw std::runtime_error("Parquet file is likely corrupted, missing dictionary"); } D_ASSERT(filter_count > 0); diff --git a/src/duckdb/extension/parquet/include/column_reader.hpp b/src/duckdb/extension/parquet/include/column_reader.hpp index 79259875b..a5d9dab05 100644 --- a/src/duckdb/extension/parquet/include/column_reader.hpp +++ b/src/duckdb/extension/parquet/include/column_reader.hpp @@ -21,11 +21,13 @@ #include "decoder/delta_length_byte_array_decoder.hpp" #include "decoder/delta_byte_array_decoder.hpp" #include "parquet_column_schema.hpp" +#include "parquet_crypto.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/types/vector_cache.hpp" +#include "duckdb/common/encryption_functions.hpp" namespace duckdb { class ParquetReader; @@ -94,6 +96,20 @@ class ColumnReader { return column_schema.max_repeat; } + void InitializeCryptoMetadata(const duckdb_parquet::EncryptionAlgorithm &encryption_algorithm, + idx_t row_group_ordinal_p) { + std::string unique_file_identifier; + if (encryption_algorithm.__isset.AES_GCM_V1) { + unique_file_identifier = encryption_algorithm.AES_GCM_V1.aad_file_unique; + } else if (encryption_algorithm.__isset.AES_GCM_CTR_V1) { + throw InternalException("File is encrypted with AES_GCM_CTR_V1, but this is not supported by DuckDB"); + } else { + throw InternalException("File is encrypted but no encryption algorithm is set"); + } + + aad_crypto_metadata.Initialize(unique_file_identifier, row_group_ordinal_p, ColumnIndex()); + } + virtual idx_t FileOffset() const; virtual uint64_t TotalCompressedSize(); virtual idx_t GroupRowsAvailable(); @@ -177,6 +193,10 @@ class ColumnReader { idx_t &approved_tuple_count); void DirectSelect(uint64_t num_values, data_ptr_t define_out, data_ptr_t repeat_out, Vector &result, const SelectionVector &sel, idx_t approved_tuple_count); + void ReadEncrypted(duckdb_apache::thrift::TBase &object); + void ReadDataEncrypted(const data_ptr_t buffer, const uint32_t buffer_size, PageType::type module); + void Read(PageHeader &page_hdr); + void ReadData(const data_ptr_t buffer, const uint32_t buffer_size, PageType::type page_type); private: //! Check if a previous table filter has filtered out this page @@ -315,6 +335,7 @@ class ColumnReader { DeltaLengthByteArrayDecoder delta_length_byte_array_decoder; DeltaByteArrayDecoder delta_byte_array_decoder; ByteStreamSplitDecoder byte_stream_split_decoder; + CryptoMetaData aad_crypto_metadata; //! 
Resizeable buffers used for the various encodings above ResizeableBuffer encoding_buffers[2]; diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index d475e903b..1463137ad 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #include "parquet_types.h" #include "parquet_column_schema.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { class MemoryStream; @@ -18,6 +19,7 @@ class ParquetWriter; class ColumnWriterPageState; class PrimitiveColumnWriterState; struct ChildFieldIDs; +struct ShreddingType; class ResizeableBuffer; class ParquetBloomFilter; @@ -62,20 +64,34 @@ class ColumnWriterPageState { } }; +struct ParquetAnalyzeSchemaState { +public: + ParquetAnalyzeSchemaState() { + } + virtual ~ParquetAnalyzeSchemaState() { + } + +public: + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } +}; + class ColumnWriter { protected: static constexpr uint16_t PARQUET_DEFINE_VALID = UINT16_C(65535); public: - ColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path, - bool can_have_nulls); + ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); virtual ~ColumnWriter(); - ParquetWriter &writer; - const ParquetColumnSchema &column_schema; - vector schema_path; - bool can_have_nulls; - public: const LogicalType &Type() const { return column_schema.type; @@ -83,8 +99,12 @@ class ColumnWriter { const ParquetColumnSchema &Schema() const { return column_schema; } + ParquetColumnSchema &Schema() { + return column_schema; + } inline idx_t SchemaIndex() const { - return column_schema.schema_index; + D_ASSERT(column_schema.schema_index.IsValid()); + return column_schema.schema_index.GetIndex(); } inline idx_t MaxDefine() const { return column_schema.max_define; @@ -92,16 +112,49 @@ class ColumnWriter { idx_t MaxRepeat() const { return column_schema.max_repeat; } + virtual bool HasTransform() { + for (auto &child_writer : child_writers) { + if (child_writer->HasTransform()) { + throw NotImplementedException("ColumnWriter of type '%s' requires a transform, but is not a root " + "column, this isn't supported currently", + child_writer->Type()); + } + } + return false; + } + virtual LogicalType TransformedType() { + throw NotImplementedException("Writer does not have a transformed type"); + } + virtual unique_ptr TransformExpression(unique_ptr expr) { + throw NotImplementedException("Writer does not have a transform expression"); + } + + virtual unique_ptr AnalyzeSchemaInit() { + return nullptr; + } + + const vector> &ChildWriters() const { + return child_writers; + } + + virtual void AnalyzeSchema(ParquetAnalyzeSchemaState &state, Vector &input, idx_t count) { + throw NotImplementedException("Writer doesn't require an AnalyzeSchema pass"); + } + + virtual void AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state) { + throw NotImplementedException("Writer doesn't require an AnalyzeSchemaFinalize pass"); + } + + virtual void FinalizeSchema(vector &schemas) = 0; - static ParquetColumnSchema FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat = 0, - idx_t max_define = 1, bool 
can_have_nulls = true); //! Create the column writer for a specific type recursively static unique_ptr CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - const vector &parquet_schemas, - const ParquetColumnSchema &schema, - vector path_in_schema); + vector path_in_schema, const LogicalType &type, + const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat = 0, idx_t max_define = 1, + bool can_have_nulls = true); virtual unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) = 0; @@ -129,10 +182,19 @@ class ColumnWriter { protected: void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, const idx_t count, const uint16_t define_value, const uint16_t null_value) const; - void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat) const; + void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count) const; void CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, AllocatedData &compressed_buf); + +public: + ParquetWriter &writer; + ParquetColumnSchema column_schema; + vector schema_path; + bool can_have_nulls; + +protected: + vector> child_writers; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp b/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp index f8141e26e..9f304da25 100644 --- a/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp +++ b/src/duckdb/extension/parquet/include/decoder/delta_length_byte_array_decoder.hpp @@ -27,7 +27,7 @@ class DeltaLengthByteArrayDecoder { void Skip(uint8_t *defines, idx_t skip_count); private: - template + template void ReadInternal(shared_ptr &block, uint8_t *defines, idx_t read_count, Vector &result, idx_t result_offset); template diff --git a/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp b/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp index c012a82dc..de75b045e 100644 --- a/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/decoder/dictionary_decoder.hpp @@ -47,11 +47,10 @@ class DictionaryDecoder { SelectionVector valid_sel; SelectionVector dictionary_selection_vector; idx_t dictionary_size; - unique_ptr dictionary; + buffer_ptr dictionary; unsafe_unique_array filter_result; idx_t filter_count; bool can_have_nulls; - string dictionary_id; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp deleted file mode 100644 index 6dc82bc8d..000000000 --- a/src/duckdb/extension/parquet/include/geo_parquet.hpp +++ /dev/null @@ -1,241 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// geo_parquet.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_writer.hpp" -#include "duckdb/common/string.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "parquet_types.h" - -namespace duckdb { - -struct ParquetColumnSchema; - -struct GeometryKindSet { - - uint8_t bits[4] = {0, 0, 0, 0}; - - void Add(uint32_t wkb_type) { - auto kind = wkb_type % 1000; - auto dims = 
wkb_type / 1000; - if (kind < 1 || kind > 7 || (dims) > 3) { - return; - } - bits[dims] |= (1 << (kind - 1)); - } - - void Combine(const GeometryKindSet &other) { - for (uint32_t d = 0; d < 4; d++) { - bits[d] |= other.bits[d]; - } - } - - bool IsEmpty() const { - for (uint32_t d = 0; d < 4; d++) { - if (bits[d] != 0) { - return false; - } - } - return true; - } - - template - vector ToList() const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for (uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - result.push_back(i + d * 1000); - } - } - } - return result; - } - - vector ToString(bool snake_case) const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for (uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - string str; - switch (i) { - case 1: - str = snake_case ? "point" : "Point"; - break; - case 2: - str = snake_case ? "linestring" : "LineString"; - break; - case 3: - str = snake_case ? "polygon" : "Polygon"; - break; - case 4: - str = snake_case ? "multipoint" : "MultiPoint"; - break; - case 5: - str = snake_case ? "multilinestring" : "MultiLineString"; - break; - case 6: - str = snake_case ? "multipolygon" : "MultiPolygon"; - break; - case 7: - str = snake_case ? "geometrycollection" : "GeometryCollection"; - break; - default: - str = snake_case ? "unknown" : "Unknown"; - break; - } - switch (d) { - case 1: - str += snake_case ? "_z" : " Z"; - break; - case 2: - str += snake_case ? "_m" : " M"; - break; - case 3: - str += snake_case ? "_zm" : " ZM"; - break; - default: - break; - } - - result.push_back(str); - } - } - } - return result; - } -}; - -struct GeometryExtent { - - double xmin = NumericLimits::Maximum(); - double xmax = NumericLimits::Minimum(); - double ymin = NumericLimits::Maximum(); - double ymax = NumericLimits::Minimum(); - double zmin = NumericLimits::Maximum(); - double zmax = NumericLimits::Minimum(); - double mmin = NumericLimits::Maximum(); - double mmax = NumericLimits::Minimum(); - - bool IsSet() const { - return xmin != NumericLimits::Maximum() && xmax != NumericLimits::Minimum() && - ymin != NumericLimits::Maximum() && ymax != NumericLimits::Minimum(); - } - - bool HasZ() const { - return zmin != NumericLimits::Maximum() && zmax != NumericLimits::Minimum(); - } - - bool HasM() const { - return mmin != NumericLimits::Maximum() && mmax != NumericLimits::Minimum(); - } - - void Combine(const GeometryExtent &other) { - xmin = std::min(xmin, other.xmin); - xmax = std::max(xmax, other.xmax); - ymin = std::min(ymin, other.ymin); - ymax = std::max(ymax, other.ymax); - zmin = std::min(zmin, other.zmin); - zmax = std::max(zmax, other.zmax); - mmin = std::min(mmin, other.mmin); - mmax = std::max(mmax, other.mmax); - } - - void Combine(const double &xmin_p, const double &xmax_p, const double &ymin_p, const double &ymax_p) { - xmin = std::min(xmin, xmin_p); - xmax = std::max(xmax, xmax_p); - ymin = std::min(ymin, ymin_p); - ymax = std::max(ymax, ymax_p); - } - - void ExtendX(const double &x) { - xmin = std::min(xmin, x); - xmax = std::max(xmax, x); - } - void ExtendY(const double &y) { - ymin = std::min(ymin, y); - ymax = std::max(ymax, y); - } - void ExtendZ(const double &z) { - zmin = std::min(zmin, z); - zmax = std::max(zmax, z); - } - void ExtendM(const double &m) { - mmin = std::min(mmin, m); - mmax = std::max(mmax, m); - } -}; - -struct GeometryStats { - GeometryKindSet types; - GeometryExtent bbox; - - void Update(const string_t &wkb); -}; - 
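// [Illustrative sketch, not part of this patch] The GeometryKindSet removed above packs ISO WKB
// type codes into four bytes, one byte per dimension group (XY, Z, M, ZM) and one bit per geometry
// kind (Point .. GeometryCollection). A minimal standalone rendering of that encoding, assuming the
// same kind = code % 1000 / dims = code / 1000 convention; AddKind is a hypothetical helper, not a
// DuckDB symbol.
#include <cassert>
#include <cstdint>

static void AddKind(uint8_t bits[4], uint32_t wkb_type) {
	const uint32_t kind = wkb_type % 1000; // 1 = Point .. 7 = GeometryCollection
	const uint32_t dims = wkb_type / 1000; // 0 = XY, 1 = Z, 2 = M, 3 = ZM
	if (kind < 1 || kind > 7 || dims > 3) {
		return; // unrecognized codes are ignored, mirroring the original Add()
	}
	bits[dims] |= static_cast<uint8_t>(1u << (kind - 1));
}

int main() {
	uint8_t bits[4] = {0, 0, 0, 0};
	AddKind(bits, 1);    // Point (XY)             -> bit 0 of bits[0]
	AddKind(bits, 1002); // LineString Z           -> bit 1 of bits[1]
	AddKind(bits, 3007); // GeometryCollection ZM  -> bit 6 of bits[3]
	assert(bits[0] == 0x01 && bits[1] == 0x02 && bits[3] == 0x40);
	return 0;
}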
-//------------------------------------------------------------------------------ -// GeoParquetMetadata -//------------------------------------------------------------------------------ -class ParquetReader; -class ColumnReader; -class ClientContext; -class ExpressionExecutor; - -enum class GeoParquetColumnEncoding : uint8_t { - WKB = 1, - POINT, - LINESTRING, - POLYGON, - MULTIPOINT, - MULTILINESTRING, - MULTIPOLYGON, -}; - -struct GeoParquetColumnMetadata { - // The encoding of the geometry column - GeoParquetColumnEncoding geometry_encoding; - - // The statistics of the geometry column - GeometryStats stats; - - // The crs of the geometry column (if any) in PROJJSON format - string projjson; - - // Used to track the "primary" geometry column (if any) - idx_t insertion_index = 0; -}; - -class GeoParquetFileMetadata { -public: - void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStats &stats); - void Write(duckdb_parquet::FileMetaData &file_meta_data); - - // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not - // available. - static unique_ptr TryRead(const duckdb_parquet::FileMetaData &file_meta_data, - const ClientContext &context); - const unordered_map &GetColumnMeta() const; - - static unique_ptr CreateColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema, - ClientContext &context); - - bool IsGeometryColumn(const string &column_name) const; - - static bool IsGeoParquetConversionEnabled(const ClientContext &context); - static LogicalType GeometryType(); - -private: - mutex write_lock; - string version = "1.1.0"; - unordered_map geometry_columns; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp index d467e2a02..cc6cfb706 100644 --- a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp +++ b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp @@ -12,10 +12,16 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; +using duckdb_parquet::SchemaElement; + using duckdb_parquet::FileMetaData; struct ParquetOptions; -enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, GEOMETRY, EXPRESSION, VARIANT }; +enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, EXPRESSION, VARIANT, GEOMETRY }; enum class ParquetExtraTypeInfo { NONE, @@ -30,29 +36,60 @@ enum class ParquetExtraTypeInfo { }; struct ParquetColumnSchema { +public: ParquetColumnSchema() = default; - ParquetColumnSchema(idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t file_index, - ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); - ParquetColumnSchema(string name, LogicalType type, idx_t max_define, idx_t max_repeat, idx_t schema_index, - idx_t column_index, ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); - ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, ParquetColumnSchemaType schema_type); + ParquetColumnSchema(ParquetColumnSchema &&other) = default; + ParquetColumnSchema(const ParquetColumnSchema &other) = default; + ParquetColumnSchema &operator=(ParquetColumnSchema &&other) = default; - ParquetColumnSchemaType schema_type; +public: + //! 
Writer constructors + static ParquetColumnSchema FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t column_index, + duckdb_parquet::FieldRepetitionType::type repetition_type, + bool allow_geometry, + ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); + +public: + //! Reader constructors + static ParquetColumnSchema FromSchemaElement(const SchemaElement &element, idx_t max_define, idx_t max_repeat, + idx_t schema_index, idx_t column_index, ParquetColumnSchemaType type, + const ParquetOptions &options); + static ParquetColumnSchema FromParentSchema(ParquetColumnSchema parent, LogicalType result_type, + ParquetColumnSchemaType schema_type); + static ParquetColumnSchema FromChildSchemas(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t schema_index, idx_t column_index, + vector &&children, + ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); + static ParquetColumnSchema FileRowNumber(); + +public: + unique_ptr Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, + idx_t row_group_idx_p, const vector &columns) const; + +public: + void SetSchemaIndex(idx_t schema_idx); + +public: string name; - LogicalType type; idx_t max_define; idx_t max_repeat; - idx_t schema_index; + //! Populated by FinalizeSchema if used in the parquet_writer path + optional_idx schema_index; idx_t column_index; + ParquetColumnSchemaType schema_type; + LogicalType type; optional_idx parent_schema_index; uint32_t type_length = 0; uint32_t type_scale = 0; duckdb_parquet::Type::type parquet_type = duckdb_parquet::Type::INT32; ParquetExtraTypeInfo type_info = ParquetExtraTypeInfo::NONE; vector children; - - unique_ptr Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, - idx_t row_group_idx_p, const vector &columns) const; + optional_idx field_id; + //! Whether a column is nullable or not + duckdb_parquet::FieldRepetitionType::type repetition_type = duckdb_parquet::FieldRepetitionType::OPTIONAL; + //! 
Whether the column can be recognized as a GEOMETRY type + bool allow_geometry = false; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_crypto.hpp b/src/duckdb/extension/parquet/include/parquet_crypto.hpp index 7261a3bc8..2189f00f8 100644 --- a/src/duckdb/extension/parquet/include/parquet_crypto.hpp +++ b/src/duckdb/extension/parquet/include/parquet_crypto.hpp @@ -9,14 +9,19 @@ #pragma once #include "parquet_types.h" +#include "duckdb/common/allocator.hpp" #include "duckdb/common/encryption_state.hpp" +#include "duckdb/common/encryption_functions.hpp" #include "duckdb/storage/object_cache.hpp" namespace duckdb { +class ParquetAdditionalAuthenticatedData; using duckdb_apache::thrift::TBase; using duckdb_apache::thrift::protocol::TProtocol; - +using duckdb_parquet::ColumnChunk; +using duckdb_parquet::PageType; +class Allocator; class BufferedFileWriter; class ParquetKeys : public ObjectCacheEntry { @@ -36,6 +41,43 @@ class ParquetKeys : public ObjectCacheEntry { unordered_map keys; }; +struct CryptoMetaData { + CryptoMetaData(Allocator &allocator); + void Initialize(const std::string &unique_file_identifier_p, int16_t row_group_ordinal = -1, + int16_t column_ordinal = -1, int8_t module = -1, int16_t page_ordinal = -1); + void ClearAdditionalAuthenticatedData(); + void SetModule(int8_t module_p); + bool IsEmpty() const; + +public: + string unique_file_identifier = ""; + int8_t module; + int16_t row_group_ordinal; + int16_t column_ordinal; + int16_t page_ordinal; + +public: + unique_ptr additional_authenticated_data; +}; + +class ParquetAdditionalAuthenticatedData : public AdditionalAuthenticatedData { +public: + explicit ParquetAdditionalAuthenticatedData(Allocator &allocator); + ~ParquetAdditionalAuthenticatedData() override; + +public: + idx_t GetPrefixSize() const; + void Rewind() const; + void WriteParquetAAD(const CryptoMetaData &crypto_meta_data); + +private: + void WritePrefix(const std::string &prefix); + void WriteSuffix(const CryptoMetaData &crypto_meta_data); + +private: + optional_idx additional_authenticated_data_prefix_size; +}; + class ParquetEncryptionConfig { public: explicit ParquetEncryptionConfig(); @@ -68,15 +110,33 @@ class ParquetCrypto { static constexpr idx_t CRYPTO_BLOCK_SIZE = 4096; static constexpr idx_t BLOCK_SIZE = 16; + // Module types for encryption + static constexpr int8_t FOOTER = 0; + static constexpr int8_t COLUMN_METADATA = 1; + static constexpr int8_t DATA_PAGE = 2; + static constexpr int8_t DICTIONARY_PAGE = 3; + static constexpr int8_t DATA_PAGE_HEADER = 4; + static constexpr int8_t DICTIONARY_PAGE_HEADER = 5; + static constexpr int8_t COLUMN_INDEX = 6; + static constexpr int8_t OFFSET_INDEX = 7; + static constexpr int8_t BLOOM_FILTER_HEADER = 8; + static constexpr int8_t BLOOM_FILTER_BITSET = 9; + + // Standard AAD length for file + static constexpr int32_t UNIQUE_FILE_ID_LEN = 8; + // Maximum Parquet AAD suffix bytes + static constexpr int32_t AAD_MAX_SUFFIX_BYTES = 7; + public: //! Decrypt and read a Thrift object from the transport protocol - static uint32_t Read(TBase &object, TProtocol &iprot, const string &key, const EncryptionUtil &encryption_util_p); + static uint32_t Read(TBase &object, TProtocol &iprot, const string &key, const EncryptionUtil &encryption_util_p, + const CryptoMetaData &crypto_meta_data); //! Encrypt and write a Thrift object to the transport protocol static uint32_t Write(const TBase &object, TProtocol &oprot, const string &key, const EncryptionUtil &encryption_util_p); //! 
Decrypt and read a buffer static uint32_t ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, const string &key, - const EncryptionUtil &encryption_util_p); + const EncryptionUtil &encryption_util_p, const CryptoMetaData &crypto_meta_data); //! Encrypt and write a buffer to a file static uint32_t WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, const string &key, const EncryptionUtil &encryption_util_p); @@ -84,6 +144,14 @@ class ParquetCrypto { public: static void AddKey(ClientContext &context, const FunctionParameters ¶meters); static bool ValidKey(const std::string &key); + +public: + static int8_t GetModuleHeader(const ColumnChunk &chunk, uint16_t page_ordinal); + static int8_t GetModule(const ColumnChunk &chunk, PageType::type page_type, uint16_t page_ordinal); + static int16_t GetFinalPageOrdinal(const ColumnChunk &chunk, uint8_t module, uint16_t page_ordinal); + static void GenerateAdditionalAuthenticatedData(Allocator &allocator, CryptoMetaData &aad_crypto_metadata); + static unique_ptr GenerateFooterAAD(Allocator &allocator, + const std::string &unique_file_identifier); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp index 31fb26cc9..775160215 100644 --- a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp +++ b/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp @@ -18,7 +18,7 @@ class DbpDecoder { : buffer_(buffer, buffer_len), // block_size_in_values(ParquetDecodeUtils::VarintDecode(buffer_)), - number_of_miniblocks_per_block(ParquetDecodeUtils::VarintDecode(buffer_)), + number_of_miniblocks_per_block(DecodeNumberOfMiniblocksPerBlock(buffer_)), number_of_values_in_a_miniblock(block_size_in_values / number_of_miniblocks_per_block), total_value_count(ParquetDecodeUtils::VarintDecode(buffer_)), previous_value(ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_))), @@ -31,7 +31,7 @@ class DbpDecoder { number_of_values_in_a_miniblock % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0)) { throw InvalidInputException("Parquet file has invalid block sizes for DELTA_BINARY_PACKED"); } - }; + } ByteBuffer BufferPtr() const { return buffer_; @@ -68,6 +68,15 @@ class DbpDecoder { } private: + static idx_t DecodeNumberOfMiniblocksPerBlock(ByteBuffer &buffer) { + auto res = ParquetDecodeUtils::VarintDecode(buffer); + if (res == 0) { + throw InvalidInputException( + "Parquet file has invalid number of miniblocks per block for DELTA_BINARY_PACKED"); + } + return res; + } + template void GetBatchInternal(const data_ptr_t target_values_ptr, const idx_t batch_size) { if (batch_size == 0) { diff --git a/src/duckdb/extension/parquet/include/parquet_field_id.hpp b/src/duckdb/extension/parquet/include/parquet_field_id.hpp new file mode 100644 index 000000000..9d5dd754c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_field_id.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct FieldID; +struct ChildFieldIDs { + ChildFieldIDs(); + ChildFieldIDs Copy() const; + unique_ptr> ids; + + void Serialize(Serializer &serializer) const; + static ChildFieldIDs Deserialize(Deserializer &source); +}; + +struct FieldID { +public: + static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; + FieldID(); + explicit FieldID(int32_t field_id); 
+ FieldID Copy() const; + bool set; + int32_t field_id; + ChildFieldIDs child_field_ids; + + void Serialize(Serializer &serializer) const; + static FieldID Deserialize(Deserializer &source); + +public: + static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types); + static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp index aa1c1c9b5..53a00186f 100644 --- a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp +++ b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp @@ -9,18 +9,20 @@ #include "duckdb.hpp" #include "duckdb/storage/object_cache.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_types.h" namespace duckdb { struct CachingFileHandle; +using duckdb_parquet::FileCryptoMetaData; enum class ParquetCacheValidity { VALID, INVALID, UNKNOWN }; class ParquetFileMetadataCache : public ObjectCacheEntry { public: ParquetFileMetadataCache(unique_ptr file_metadata, CachingFileHandle &handle, - unique_ptr geo_metadata, idx_t footer_size); + unique_ptr geo_metadata, + unique_ptr crypto_metadata, idx_t footer_size); ~ParquetFileMetadataCache() override = default; //! Parquet file metadata @@ -29,6 +31,9 @@ class ParquetFileMetadataCache : public ObjectCacheEntry { //! GeoParquet metadata unique_ptr geo_metadata; + //! Crypto metadata + unique_ptr crypto_metadata; + //! Parquet footer size idx_t footer_size; diff --git a/src/duckdb/extension/parquet/include/parquet_geometry.hpp b/src/duckdb/extension/parquet/include/parquet_geometry.hpp new file mode 100644 index 000000000..3c367ee37 --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_geometry.hpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// geo_parquet.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "column_writer.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "parquet_types.h" + +namespace duckdb { + +struct ParquetColumnSchema; +class ParquetReader; +class ColumnReader; +class ClientContext; + +struct GeometryColumnReader { + static unique_ptr Create(ParquetReader &reader, const ParquetColumnSchema &schema, + ClientContext &context); +}; + +enum class GeoParquetColumnEncoding : uint8_t { + WKB = 1, + POINT, + LINESTRING, + POLYGON, + MULTIPOINT, + MULTILINESTRING, + MULTIPOLYGON, +}; + +enum class GeoParquetVersion : uint8_t { + // Write GeoParquet 1.0 metadata + // GeoParquet 1.0 has the widest support among readers and writers + V1, + + // Write GeoParquet 2.0 + // The GeoParquet 2.0 options is identical to GeoParquet 1.0 except the underlying storage + // of spatial columns is Parquet native geometry, where the Parquet writer will include + // native statistics according to the underlying Parquet options. 
Compared to 'BOTH', this will + actually write the metadata as containing GeoParquet version 2.0.0 + // However, V2 isn't standardized yet, so this option is still a bit experimental + V2, + + // Write GeoParquet 1.0 metadata, with native Parquet geometry types + // This is a bit of a hold-over option for compatibility with systems that + // reject GeoParquet 2.0 metadata, but can read Parquet native geometry types as they simply ignore the extra + // logical type. DuckDB v1.4.0 falls into this category. + BOTH, + + // Do not write GeoParquet metadata + // This option suppresses GeoParquet metadata; however, spatial types will be written as + // Parquet native Geometry/Geography. + NONE, +}; + +struct GeoParquetColumnMetadata { + // The encoding of the geometry column + GeoParquetColumnEncoding geometry_encoding; + + // The statistics of the geometry column + GeometryStatsData stats; + + // The crs of the geometry column (if any) in PROJJSON format + string projjson; + + // Used to track the "primary" geometry column (if any) + idx_t insertion_index = 0; + + GeoParquetColumnMetadata() { + geometry_encoding = GeoParquetColumnEncoding::WKB; + stats.SetEmpty(); + } +}; + +class GeoParquetFileMetadata { +public: + explicit GeoParquetFileMetadata(GeoParquetVersion geo_parquet_version) : version(geo_parquet_version) { + } + void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStatsData &stats); + void Write(duckdb_parquet::FileMetaData &file_meta_data); + + // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not + // available. + static unique_ptr TryRead(const duckdb_parquet::FileMetaData &file_meta_data, + const ClientContext &context); + const unordered_map &GetColumnMeta() const; + + bool IsGeometryColumn(const string &column_name) const; + + static bool IsGeoParquetConversionEnabled(const ClientContext &context); + +private: + mutex write_lock; + unordered_map geometry_columns; + GeoParquetVersion version; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_metadata.hpp b/src/duckdb/extension/parquet/include/parquet_metadata.hpp index 09ecd5afa..fb0900610 100644 --- a/src/duckdb/extension/parquet/include/parquet_metadata.hpp +++ b/src/duckdb/extension/parquet/include/parquet_metadata.hpp @@ -38,4 +38,9 @@ class ParquetBloomProbeFunction : public TableFunction { ParquetBloomProbeFunction(); }; +class ParquetFullMetadataFunction : public TableFunction { +public: + ParquetFullMetadataFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp index de905c70c..c8fde1eed 100644 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ b/src/duckdb/extension/parquet/include/parquet_reader.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #include "duckdb/storage/caching_file_system.hpp" #include "duckdb/common/common.hpp" +#include "duckdb/common/encryption_functions.hpp" #include "duckdb/common/encryption_state.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/multi_file/base_file_reader.hpp" @@ -105,7 +106,6 @@ struct ParquetOptions { explicit ParquetOptions(ClientContext &context); bool binary_as_string = false; - bool variant_legacy_encoding = false; bool file_row_number = false; shared_ptr encryption_config; bool debug_use_openssl = true; @@ -166,23 +166,28 @@ class ParquetReader : public BaseFileReader { bool
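Condensing the write-side behavior that the GeoParquetVersion comments above describe into one place; GeoOutputChoice and ChooseGeoOutput are illustrative helpers, not part of the extension, and the inference that V1 keeps WKB-only storage follows from the V2 comment rather than an explicit statement:

// Illustrative only: summarizes the GeoParquetVersion comments above.
struct GeoOutputChoice {
	bool write_geo_metadata;      // emit the "geo" key-value metadata blob
	const char *metadata_version; // version string recorded in that blob (when written)
	bool use_native_geometry;     // annotate columns with the Parquet-native GEOMETRY logical type
};

static GeoOutputChoice ChooseGeoOutput(GeoParquetVersion version) {
	switch (version) {
	case GeoParquetVersion::V1: // widest reader support: 1.0 metadata, WKB storage
		return {true, "1.0.0", false};
	case GeoParquetVersion::V2: // 2.0 metadata plus native geometry; still experimental
		return {true, "2.0.0", true};
	case GeoParquetVersion::BOTH: // 1.0 metadata for older readers, native geometry for newer ones
		return {true, "1.0.0", true};
	case GeoParquetVersion::NONE: // no GeoParquet metadata, spatial types still written natively
	default:
		return {false, nullptr, true};
	}
}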
TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; public: void InitializeScan(ClientContext &context, ParquetReaderScanState &state, vector groups_to_read); - void Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); + AsyncResult Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); idx_t NumRows() const; idx_t NumRowGroups() const; const duckdb_parquet::FileMetaData *GetFileMetadata() const; + string static GetUniqueFileIdentifier(const duckdb_parquet::EncryptionAlgorithm &encryption_algorithm); uint32_t Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot); + uint32_t ReadEncrypted(duckdb_apache::thrift::TBase &object, TProtocol &iprot, + CryptoMetaData &aad_crypto_metadata) const; uint32_t ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size); + uint32_t ReadDataEncrypted(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, + const uint32_t buffer_size, CryptoMetaData &aad_crypto_metadata) const; unique_ptr ReadStatistics(const string &name); @@ -195,6 +200,8 @@ class ParquetReader : public BaseFileReader { static unique_ptr ReadStatistics(const ParquetUnionData &union_data, const string &name); LogicalType DeriveLogicalType(const SchemaElement &s_ele, ParquetColumnSchema &schema) const; + static LogicalType DeriveLogicalType(const SchemaElement &s_ele, const ParquetOptions &options, + ParquetColumnSchema &schema); void AddVirtualColumn(column_t virtual_column_id) override; @@ -209,7 +216,6 @@ class ParquetReader : public BaseFileReader { shared_ptr metadata); void InitializeSchema(ClientContext &context); - bool ScanInternal(ClientContext &context, ParquetReaderScanState &state, DataChunk &output); //! 
Parse the schema of the file unique_ptr ParseSchema(ClientContext &context); ParquetColumnSchema ParseSchemaRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, idx_t &next_schema_idx, @@ -231,6 +237,8 @@ class ParquetReader : public BaseFileReader { MultiFileColumnDefinition ParseColumnDefinition(const duckdb_parquet::FileMetaData &file_meta_data, ParquetColumnSchema &element); + unique_ptr GenerateAAD(uint8_t module_type, uint16_t row_group_ordinal, + uint16_t column_ordinal, uint16_t page_ordinal) const; private: unique_ptr file_handle; diff --git a/src/duckdb/extension/parquet/include/parquet_shredding.hpp b/src/duckdb/extension/parquet/include/parquet_shredding.hpp new file mode 100644 index 000000000..f43cbc42c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_shredding.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types/variant.hpp" + +namespace duckdb { + +struct ShreddingType; + +struct ChildShreddingTypes { +public: + ChildShreddingTypes(); + +public: + ChildShreddingTypes Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ChildShreddingTypes Deserialize(Deserializer &source); + +public: + unique_ptr> types; +}; + +struct ShreddingType { +public: + ShreddingType(); + explicit ShreddingType(const LogicalType &type); + +public: + ShreddingType Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ShreddingType Deserialize(Deserializer &source); + +public: + static ShreddingType GetShreddingTypes(const Value &val); + void AddChild(const string &name, ShreddingType &&child); + optional_ptr GetChild(const string &name) const; + +public: + bool set = false; + LogicalType type; + ChildShreddingTypes children; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_statistics.hpp b/src/duckdb/extension/parquet/include/parquet_statistics.hpp index cb05dae3b..e138d9763 100644 --- a/src/duckdb/extension/parquet/include/parquet_statistics.hpp +++ b/src/duckdb/extension/parquet/include/parquet_statistics.hpp @@ -23,7 +23,6 @@ struct ParquetColumnSchema; class ResizeableBuffer; struct ParquetStatisticsUtils { - static unique_ptr TransformColumnStatistics(const ParquetColumnSchema &reader, const vector &columns, bool can_have_nan); diff --git a/src/duckdb/extension/parquet/include/parquet_support.hpp b/src/duckdb/extension/parquet/include/parquet_support.hpp index 91c43fcb4..0b00e6242 100644 --- a/src/duckdb/extension/parquet/include/parquet_support.hpp +++ b/src/duckdb/extension/parquet/include/parquet_support.hpp @@ -118,7 +118,6 @@ class StripeStreams { }; class ColumnReader { - public: ColumnReader(const EncodingKey &ek, StripeStreams &stripe); diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index a2bfc3a80..f28b07bbe 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -21,8 +21,10 @@ #include "parquet_statistics.hpp" #include "column_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" #include "parquet_types.h" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "writer/parquet_write_stats.hpp" #include "thrift/protocol/TCompactProtocol.h" @@ -43,29 +45,6 @@ struct PreparedRowGroup { vector> states; }; -struct FieldID; -struct ChildFieldIDs { - 
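The ShreddingType / ChildShreddingTypes structures above carry, per VARIANT column, the concrete type (and nested children) the writer should shred into; they are normally produced from the SHREDDING copy option via ShreddingType::GetShreddingTypes. A hedged sketch that builds an equivalent specification directly through the declared API; the construction pattern shown is an assumption, only the declarations above are real:

#include "parquet_shredding.hpp"
using namespace duckdb;

// Hand-build the shredding spec {col1: INTEGER, col2: BOOLEAN} using only the API declared above.
ShreddingType BuildExampleShreddingSpec() {
	ShreddingType root; // 'set' stays false for the root container
	root.AddChild("col1", ShreddingType(LogicalType::INTEGER));
	root.AddChild("col2", ShreddingType(LogicalType::BOOLEAN));
	return root;
}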
ChildFieldIDs(); - ChildFieldIDs Copy() const; - unique_ptr> ids; - - void Serialize(Serializer &serializer) const; - static ChildFieldIDs Deserialize(Deserializer &source); -}; - -struct FieldID { - static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; - FieldID(); - explicit FieldID(int32_t field_id); - FieldID Copy() const; - bool set; - int32_t field_id; - ChildFieldIDs child_field_ids; - - void Serialize(Serializer &serializer) const; - static FieldID Deserialize(Deserializer &source); -}; - struct ParquetBloomFilterEntry { unique_ptr bloom_filter; idx_t row_group_idx; @@ -77,25 +56,74 @@ enum class ParquetVersion : uint8_t { V2 = 2, //! Includes the encodings above }; +class ParquetWriteTransformData { +public: + ParquetWriteTransformData(ClientContext &context, vector types, + vector> expressions); + +public: + ColumnDataCollection &ApplyTransform(ColumnDataCollection &input); + +private: + //! The buffer to store the transformed chunks of a rowgroup + ColumnDataCollection buffer; + //! The expression(s) to apply to the input chunk + vector> expressions; + //! The expression executor used to transform the input chunk + ExpressionExecutor executor; + //! The intermediate chunk to target the transform to + DataChunk chunk; +}; + +struct ParquetWriteLocalState : public LocalFunctionData { +public: + explicit ParquetWriteLocalState(ClientContext &context, const vector &types); + +public: + ColumnDataCollection buffer; + ColumnDataAppendState append_state; + //! If any of the column writers require a transformation to a different shape, this will be initialized and used + unique_ptr transform_data; +}; + +struct ParquetWriteGlobalState : public GlobalFunctionData { +public: + ParquetWriteGlobalState() { + } + +public: + void LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason); + +public: + unique_ptr writer; + optional_ptr op; + mutex lock; + unique_ptr combine_buffer; + //! 
If any of the column writers require a transformation to a different shape, this will be initialized and used + unique_ptr transform_data; +}; + class ParquetWriter { public: ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, vector names, duckdb_parquet::CompressionCodec::type codec, ChildFieldIDs field_ids, - const vector> &kv_metadata, + ShreddingType shredding_types, const vector> &kv_metadata, shared_ptr encryption_config, optional_idx dictionary_size_limit, idx_t string_dictionary_page_size_limit, bool enable_bloom_filters, double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl, - ParquetVersion parquet_version); + ParquetVersion parquet_version, GeoParquetVersion geoparquet_version); ~ParquetWriter(); public: - void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result); + void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result, + unique_ptr &transform_data); void FlushRowGroup(PreparedRowGroup &row_group); - void Flush(ColumnDataCollection &buffer); + void Flush(ColumnDataCollection &buffer, unique_ptr &transform_data); void Finalize(); static duckdb_parquet::Type::type DuckDBTypeToParquetType(const LogicalType &duckdb_type); - static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele); + static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry); ClientContext &GetContext() { return context; @@ -139,9 +167,13 @@ class ParquetWriter { ParquetVersion GetParquetVersion() const { return parquet_version; } + GeoParquetVersion GetGeoParquetVersion() const { + return geoparquet_version; + } const string &GetFileName() const { return file_name; } + void AnalyzeSchema(ColumnDataCollection &buffer, vector> &column_writers); uint32_t Write(const duckdb_apache::thrift::TBase &object); uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size); @@ -155,6 +187,8 @@ class ParquetWriter { void SetWrittenStatistics(CopyFunctionFileStatistics &written_stats); void FlushColumnStats(idx_t col_idx, duckdb_parquet::ColumnChunk &chunk, optional_ptr writer_stats); + void InitializePreprocessing(unique_ptr &transform_data); + void InitializeSchemaElements(); private: void GatherWrittenStatistics(); @@ -166,6 +200,7 @@ class ParquetWriter { vector column_names; duckdb_parquet::CompressionCodec::type codec; ChildFieldIDs field_ids; + ShreddingType shredding_types; shared_ptr encryption_config; optional_idx dictionary_size_limit; idx_t string_dictionary_page_size_limit; @@ -175,7 +210,7 @@ class ParquetWriter { bool debug_use_openssl; shared_ptr encryption_util; ParquetVersion parquet_version; - vector column_schemas; + GeoParquetVersion geoparquet_version; unique_ptr writer; //! 
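ParquetWriteTransformData above bundles a reshaped buffer, the transform expressions, an ExpressionExecutor, and an intermediate chunk so that row groups can be rewritten (for example, VARIANT columns shredded into the STRUCT shape the writers expect) before PrepareRowGroup/Flush see them. A sketch of what that step has to do, written as a free function over the same pieces; the real ApplyTransform may differ in buffering and reuse details:

// Sketch only: roughly what a transform step like ParquetWriteTransformData::ApplyTransform
// performs, expressed over the members listed in the header above.
ColumnDataCollection &ApplyVariantTransform(ColumnDataCollection &input, ColumnDataCollection &buffer,
                                            ExpressionExecutor &executor, DataChunk &reshaped) {
	buffer.Reset();
	for (auto &input_chunk : input.Chunks()) {
		reshaped.Reset();
		// evaluate the transform expression(s), e.g. the variant shredding function,
		// producing the chunk in the shape the column writers expect
		executor.Execute(input_chunk, reshaped);
		buffer.Append(reshaped);
	}
	return buffer;
}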
Atomics to reduce contention when rotating writes to multiple Parquet files diff --git a/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp index 1ead9cf04..0f93bf9d5 100644 --- a/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/interval_column_reader.hpp @@ -57,7 +57,6 @@ struct IntervalValueConversion { }; class IntervalColumnReader : public TemplatedColumnReader { - public: IntervalColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) : TemplatedColumnReader(reader, schema) { diff --git a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp index 4bc19516a..d0d18b80c 100644 --- a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp @@ -14,16 +14,30 @@ namespace duckdb { class StringColumnReader : public ColumnReader { +public: + enum class StringColumnType : uint8_t { VARCHAR, JSON, OTHER }; + + static StringColumnType GetStringColumnType(const LogicalType &type) { + if (type.IsJSONType()) { + return StringColumnType::JSON; + } + if (type.id() == LogicalTypeId::VARCHAR) { + return StringColumnType::VARCHAR; + } + return StringColumnType::OTHER; + } + public: static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; public: StringColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema); idx_t fixed_width_string_length; + const StringColumnType string_column_type; public: static void VerifyString(const char *str_data, uint32_t str_len, const bool isVarchar); - void VerifyString(const char *str_data, uint32_t str_len); + void VerifyString(const char *str_data, uint32_t str_len) const; static void ReferenceBlock(Vector &result, shared_ptr &block); diff --git a/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp index 3bd0e96d6..b6bd55cc7 100644 --- a/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/templated_column_reader.hpp @@ -79,7 +79,6 @@ class TemplatedColumnReader : public ColumnReader { template struct CallbackParquetValueConversion { - template static DUCKDB_PHYSICAL_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { if (CHECKED) { diff --git a/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp index 86193d9a6..22d468d0f 100644 --- a/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/uuid_column_reader.hpp @@ -50,7 +50,6 @@ struct UUIDValueConversion { }; class UUIDColumnReader : public TemplatedColumnReader { - public: UUIDColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) : TemplatedColumnReader(reader, schema) { diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp index a7c717709..0f5d4e91c 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp @@ -2,7 +2,8 @@ #include "duckdb/common/types/string_type.hpp" #include 
"duckdb/common/types/value.hpp" -#include "reader/variant/variant_value.hpp" +#include "duckdb/common/types/variant_value.hpp" +#include "yyjson.hpp" using namespace duckdb_yyjson; @@ -137,10 +138,8 @@ class VariantBinaryDecoder { static VariantValue Decode(const VariantMetadata &metadata, const_data_ptr_t data); public: - static VariantValue PrimitiveTypeDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); - static VariantValue ShortStringDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); + static VariantValue PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); + static VariantValue ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); static VariantValue ObjectDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, const_data_ptr_t data); static VariantValue ArrayDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp index 27ece7d70..8fe38ce9a 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp @@ -1,6 +1,6 @@ #pragma once -#include "reader/variant/variant_value.hpp" +#include "duckdb/common/types/variant_value.hpp" #include "reader/variant/variant_binary_decoder.hpp" namespace duckdb { @@ -11,13 +11,14 @@ class VariantShreddedConversion { public: static vector Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, idx_t total_size, - bool is_field = false); + bool is_field); static vector ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, - idx_t offset, idx_t length, idx_t total_size); + idx_t offset, idx_t length, idx_t total_size, + const bool is_field); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp index 78670b14a..69b429626 100644 --- a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp @@ -15,7 +15,7 @@ namespace duckdb { class VariantColumnReader : public ColumnReader { public: - static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; + static constexpr const PhysicalType TYPE = PhysicalType::STRUCT; public: VariantColumnReader(ClientContext &context, ParquetReader &reader, const ParquetColumnSchema &schema, diff --git a/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp index 1ebb16c04..404430e1e 100644 --- a/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp @@ -14,9 +14,9 @@ namespace duckdb { class 
ArrayColumnWriter : public ListColumnWriter { public: - ArrayColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - unique_ptr child_writer_p, bool can_have_nulls) - : ListColumnWriter(writer, column_schema, std::move(schema_path_p), std::move(child_writer_p), can_have_nulls) { + ArrayColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + unique_ptr child_writer_p) + : ListColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writer_p)) { } ~ArrayColumnWriter() override = default; diff --git a/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp index eeaa3d23c..a5606a125 100644 --- a/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp @@ -14,8 +14,7 @@ namespace duckdb { class BooleanColumnWriter : public PrimitiveColumnWriter { public: - BooleanColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - bool can_have_nulls); + BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~BooleanColumnWriter() override = default; public: diff --git a/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp index 38c696571..91ced2899 100644 --- a/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp @@ -14,8 +14,7 @@ namespace duckdb { class FixedDecimalColumnWriter : public PrimitiveColumnWriter { public: - FixedDecimalColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls); + FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~FixedDecimalColumnWriter() override = default; public: diff --git a/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp index 4e3e6e3aa..ba0f6c454 100644 --- a/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp @@ -15,8 +15,7 @@ class EnumWriterPageState; class EnumColumnWriter : public PrimitiveColumnWriter { public: - EnumColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - bool can_have_nulls); + EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~EnumColumnWriter() override = default; uint32_t bit_width; diff --git a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp index f1070b0f1..df7ecf276 100644 --- a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp @@ -26,15 +26,13 @@ class ListColumnWriterState : public ColumnWriterState { class ListColumnWriter : public ColumnWriter { public: - ListColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - unique_ptr child_writer_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - 
child_writer(std::move(child_writer_p)) { + ListColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + unique_ptr child_writer_p) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { + child_writers.push_back(std::move(child_writer_p)); } ~ListColumnWriter() override = default; - unique_ptr child_writer; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; @@ -46,6 +44,10 @@ class ListColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; + +protected: + ColumnWriter &GetChildWriter(); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp index 1016c81fe..840830e3a 100644 --- a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp +++ b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp @@ -9,7 +9,7 @@ #pragma once #include "column_writer.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" namespace duckdb { @@ -28,7 +28,7 @@ class ColumnWriterStatistics { virtual bool MaxIsExact(); virtual bool HasGeoStats(); - virtual optional_ptr GetGeoStats(); + virtual optional_ptr GetGeoStats(); virtual void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats); public: @@ -255,10 +255,11 @@ class UUIDStatisticsState : public ColumnWriterStatistics { class GeoStatisticsState final : public ColumnWriterStatistics { public: explicit GeoStatisticsState() : has_stats(false) { + geo_stats.SetEmpty(); } bool has_stats; - GeometryStats geo_stats; + GeometryStatsData geo_stats; public: void Update(const string_t &val) { @@ -268,37 +269,36 @@ class GeoStatisticsState final : public ColumnWriterStatistics { bool HasGeoStats() override { return has_stats; } - optional_ptr GetGeoStats() override { + optional_ptr GetGeoStats() override { return geo_stats; } void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats) override { const auto &types = geo_stats.types; - const auto &bbox = geo_stats.bbox; - - if (bbox.IsSet()) { + const auto &bbox = geo_stats.extent; + if (bbox.HasXY()) { stats.__isset.bbox = true; - stats.bbox.xmin = bbox.xmin; - stats.bbox.xmax = bbox.xmax; - stats.bbox.ymin = bbox.ymin; - stats.bbox.ymax = bbox.ymax; + stats.bbox.xmin = bbox.x_min; + stats.bbox.xmax = bbox.x_max; + stats.bbox.ymin = bbox.y_min; + stats.bbox.ymax = bbox.y_max; if (bbox.HasZ()) { stats.bbox.__isset.zmin = true; stats.bbox.__isset.zmax = true; - stats.bbox.zmin = bbox.zmin; - stats.bbox.zmax = bbox.zmax; + stats.bbox.zmin = bbox.z_min; + stats.bbox.zmax = bbox.z_max; } if (bbox.HasM()) { stats.bbox.__isset.mmin = true; stats.bbox.__isset.mmax = true; - stats.bbox.mmin = bbox.mmin; - stats.bbox.mmax = bbox.mmax; + stats.bbox.mmin = bbox.m_min; + stats.bbox.mmax = bbox.m_max; } } stats.__isset.geospatial_types = true; - stats.geospatial_types = types.ToList(); + stats.geospatial_types = types.ToWKBList(); } }; diff --git a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp index 28b217692..36874cf6d 100644 --- a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp +++ 
b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp @@ -57,8 +57,7 @@ class PrimitiveColumnWriterState : public ColumnWriterState { //! Base class for writing non-compound types (ex. numerics, strings) class PrimitiveColumnWriter : public ColumnWriter { public: - PrimitiveColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path, - bool can_have_nulls); + PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); ~PrimitiveColumnWriter() override = default; //! We limit the uncompressed page size to 100MB @@ -75,6 +74,7 @@ class PrimitiveColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; protected: static void WriteLevels(Allocator &allocator, WriteStream &temp_writer, const unsafe_vector &levels, diff --git a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp index 8927c391b..a3d433467 100644 --- a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp @@ -14,15 +14,13 @@ namespace duckdb { class StructColumnWriter : public ColumnWriter { public: - StructColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, - vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writers(std::move(child_writers_p)) { + StructColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + vector> child_writers_p) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { + child_writers = std::move(child_writers_p); } ~StructColumnWriter() override = default; - vector> child_writers; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; @@ -34,6 +32,7 @@ class StructColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + void FinalizeSchema(vector &schemas) override; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index c035bba43..ea3f516ef 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -116,17 +116,16 @@ class StandardWriterPageState : public ColumnWriterPageState { template class StandardColumnWriter : public PrimitiveColumnWriter { public: - StandardColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, // NOLINT - bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + StandardColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } ~StandardColumnWriter() override = default; public: unique_ptr 
InitializeWriteState(duckdb_parquet::RowGroup &row_group) override { auto result = make_uniq>(writer, row_group, row_group.columns.size()); - result->encoding = duckdb_parquet::Encoding::RLE_DICTIONARY; + result->encoding = writer.GetParquetVersion() == ParquetVersion::V1 ? duckdb_parquet::Encoding::PLAIN_DICTIONARY + : duckdb_parquet::Encoding::RLE_DICTIONARY; RegisterToRowGroup(row_group); return std::move(result); } @@ -150,6 +149,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { } page_state.dbp_encoder.FinishWrite(temp_writer); break; + case duckdb_parquet::Encoding::PLAIN_DICTIONARY: + // PLAIN_DICTIONARY can be treated the same as RLE_DICTIONARY case duckdb_parquet::Encoding::RLE_DICTIONARY: D_ASSERT(page_state.dict_bit_width != 0); if (!page_state.dict_written_value) { @@ -197,6 +198,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { const bool check_parent_empty = parent && !parent->is_empty.empty(); const idx_t parent_index = state.definition_levels.size(); + D_ASSERT(!check_parent_empty || parent_index < parent->is_empty.size()); const idx_t vcount = check_parent_empty ? parent->definition_levels.size() - state.definition_levels.size() : count; @@ -207,7 +209,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { // Fast path for (; vector_index < vcount; vector_index++) { const auto &src_value = data_ptr[vector_index]; - state.dictionary.Insert(src_value); + state.dictionary.template Insert(src_value); state.total_value_count++; state.total_string_size += DlbaEncoder::GetStringSize(src_value); } @@ -218,7 +220,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { } if (validity.RowIsValid(vector_index)) { const auto &src_value = data_ptr[vector_index]; - state.dictionary.Insert(src_value); + state.dictionary.template Insert(src_value); state.total_value_count++; state.total_string_size += DlbaEncoder::GetStringSize(src_value); } @@ -265,7 +267,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { bool HasDictionary(PrimitiveColumnWriterState &state_p) override { auto &state = state_p.Cast>(); - return state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY; + return state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY || + state.encoding == duckdb_parquet::Encoding::PLAIN_DICTIONARY; } idx_t DictionarySize(PrimitiveColumnWriterState &state_p) override { @@ -285,7 +288,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { void FlushDictionary(PrimitiveColumnWriterState &state_p, ColumnWriterStatistics *stats) override { auto &state = state_p.Cast>(); - D_ASSERT(state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY); + D_ASSERT(state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY || + state.encoding == duckdb_parquet::Encoding::PLAIN_DICTIONARY); if (writer.EnableBloomFilters()) { state.bloom_filter = @@ -310,7 +314,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { idx_t GetRowSize(const Vector &vector, const idx_t index, const PrimitiveColumnWriterState &state_p) const override { auto &state = state_p.Cast>(); - if (state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY) { + if (state.encoding == duckdb_parquet::Encoding::RLE_DICTIONARY || + state.encoding == duckdb_parquet::Encoding::PLAIN_DICTIONARY) { return (state.key_bit_width + 7) / 8; } else { return OP::template GetRowSize(vector, index); @@ -328,6 +333,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { const auto *data_ptr = FlatVector::GetData(input_column); switch (page_state.encoding) { + case 
duckdb_parquet::Encoding::PLAIN_DICTIONARY: + // PLAIN_DICTIONARY can be treated the same as RLE_DICTIONARY case duckdb_parquet::Encoding::RLE_DICTIONARY: { idx_t r = chunk_start; if (!page_state.dict_written_value) { diff --git a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp new file mode 100644 index 000000000..07250c4ac --- /dev/null +++ b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// writer/variant_column_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "struct_column_writer.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +namespace duckdb { + +using variant_type_map = array(VariantLogicalType::ENUM_SIZE)>; + +struct ObjectAnalyzeData; +struct ArrayAnalyzeData; + +struct VariantAnalyzeData { +public: + VariantAnalyzeData() { + } + +public: + //! Map for every value what type it is + variant_type_map type_map = {}; + //! Map for every decimal value what physical type it has + array decimal_type_map = {}; + unique_ptr object_data = nullptr; + unique_ptr array_data = nullptr; +}; + +struct ObjectAnalyzeData { +public: + ObjectAnalyzeData() { + } + +public: + case_insensitive_map_t fields; +}; + +struct ArrayAnalyzeData { +public: + ArrayAnalyzeData() { + } + +public: + VariantAnalyzeData child; +}; + +struct VariantAnalyzeSchemaState : public ParquetAnalyzeSchemaState { +public: + VariantAnalyzeSchemaState() { + } + ~VariantAnalyzeSchemaState() override { + } + +public: + VariantAnalyzeData analyze_data; +}; + +class VariantColumnWriter : public StructColumnWriter { +public: + VariantColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + vector> child_writers_p) + : StructColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writers_p)) { + } + ~VariantColumnWriter() override = default; + +public: + void FinalizeSchema(vector &schemas) override; + unique_ptr AnalyzeSchemaInit() override; + void AnalyzeSchema(ParquetAnalyzeSchemaState &state, Vector &input, idx_t count) override; + void AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state) override; + + bool HasTransform() override { + return true; + } + LogicalType TransformedType() override { + child_list_t children; + for (auto &writer : child_writers) { + auto &child_name = writer->Schema().name; + auto &child_type = writer->Schema().type; + children.emplace_back(child_name, child_type); + } + return LogicalType::STRUCT(std::move(children)); + } + unique_ptr TransformExpression(unique_ptr expr) override { + vector> arguments; + arguments.push_back(unique_ptr_cast(std::move(expr))); + + return make_uniq(TransformedType(), GetTransformFunction(), std::move(arguments), + nullptr, false); + } + +public: + static ScalarFunction GetTransformFunction(); + static LogicalType TransformTypedValueRecursive(const LogicalType &type); + +private: + //! 
Whether the schema of the variant has been analyzed already + bool is_analyzed = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_column_schema.cpp b/src/duckdb/extension/parquet/parquet_column_schema.cpp new file mode 100644 index 000000000..64409c1ab --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_column_schema.cpp @@ -0,0 +1,113 @@ +#include "parquet_column_schema.hpp" +#include "parquet_reader.hpp" + +namespace duckdb { + +void ParquetColumnSchema::SetSchemaIndex(idx_t schema_idx) { + D_ASSERT(!schema_index.IsValid()); + schema_index = schema_idx; +} + +//! Writer constructors + +ParquetColumnSchema ParquetColumnSchema::FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t column_index, + duckdb_parquet::FieldRepetitionType::type repetition_type, + bool allow_geometry, ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.column_index = column_index; + res.repetition_type = repetition_type; + res.schema_type = schema_type; + res.type = type; + res.allow_geometry = allow_geometry; + return res; +} + +//! Reader constructors + +ParquetColumnSchema ParquetColumnSchema::FromSchemaElement(const duckdb_parquet::SchemaElement &element, + idx_t max_define, idx_t max_repeat, idx_t schema_index, + idx_t column_index, ParquetColumnSchemaType schema_type, + const ParquetOptions &options) { + ParquetColumnSchema res; + res.name = element.name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.schema_index = schema_index; + res.column_index = column_index; + res.schema_type = schema_type; + res.type = ParquetReader::DeriveLogicalType(element, options, res); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FromParentSchema(ParquetColumnSchema parent, LogicalType result_type, + ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = parent.name; + res.max_define = parent.max_define; + res.max_repeat = parent.max_repeat; + D_ASSERT(parent.schema_index.IsValid()); + res.schema_index = parent.schema_index; + res.column_index = parent.column_index; + res.schema_type = schema_type; + res.type = result_type; + res.children.push_back(std::move(parent)); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FromChildSchemas(const string &name, const LogicalType &type, idx_t max_define, + idx_t max_repeat, idx_t schema_index, idx_t column_index, + vector &&children, + ParquetColumnSchemaType schema_type) { + ParquetColumnSchema res; + res.name = name; + res.max_define = max_define; + res.max_repeat = max_repeat; + res.schema_index = schema_index; + res.column_index = column_index; + res.schema_type = schema_type; + res.type = type; + res.children = std::move(children); + return res; +} + +ParquetColumnSchema ParquetColumnSchema::FileRowNumber() { + ParquetColumnSchema res; + res.name = "file_row_number"; + res.max_define = 0; + res.max_repeat = 0; + res.schema_index = 0; + res.column_index = 0; + res.schema_type = ParquetColumnSchemaType::FILE_ROW_NUMBER; + res.type = LogicalType::BIGINT, res.repetition_type = duckdb_parquet::FieldRepetitionType::type::OPTIONAL; + return res; +} + +unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, + const ParquetOptions &parquet_options, idx_t row_group_idx_p, + const vector &columns) const { + if (schema_type == ParquetColumnSchemaType::EXPRESSION) { + return nullptr; + } + if (schema_type == 
ParquetColumnSchemaType::FILE_ROW_NUMBER) { + auto stats = NumericStats::CreateUnknown(type); + auto &row_groups = file_meta_data.row_groups; + D_ASSERT(row_group_idx_p < row_groups.size()); + idx_t row_group_offset_min = 0; + for (idx_t i = 0; i < row_group_idx_p; i++) { + row_group_offset_min += row_groups[i].num_rows; + } + + NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); + NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + + row_groups[row_group_idx_p].num_rows))); + stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + return stats.ToUnique(); + } + return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns, parquet_options.can_have_nan); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp index b60c01155..1c72385dc 100644 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ b/src/duckdb/extension/parquet/parquet_crypto.cpp @@ -7,6 +7,11 @@ #include "duckdb/common/helper.hpp" #include "duckdb/common/types/blob.hpp" #include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/common/encryption_functions.hpp" +#include "duckdb/common/allocator.hpp" + +using duckdb_parquet::ColumnChunk; +class Allocator; namespace duckdb { @@ -39,6 +44,77 @@ string ParquetKeys::GetObjectType() { return ObjectType(); } +ParquetAdditionalAuthenticatedData::ParquetAdditionalAuthenticatedData(Allocator &allocator) + : AdditionalAuthenticatedData(allocator) { +} + +ParquetAdditionalAuthenticatedData::~ParquetAdditionalAuthenticatedData() = default; + +idx_t ParquetAdditionalAuthenticatedData::GetPrefixSize() const { + if (!additional_authenticated_data_prefix_size.IsValid()) { + return 0; + } + return additional_authenticated_data_prefix_size.GetIndex(); +} + +void ParquetAdditionalAuthenticatedData::Rewind() const { + additional_authenticated_data->SetPosition(GetPrefixSize()); +} + +void ParquetAdditionalAuthenticatedData::WriteParquetAAD(const CryptoMetaData &crypto_meta_data) { + // For the parquet encryption spec, additional authenticated data (AAD) consists of: + // (1) a unique prefix constructed by: + // an optional aad-prefix (arbitrary length -- ignored for now) + // + a unique file identifier (default 8 bytes) + if (GetPrefixSize() == 0) { + WritePrefix(crypto_meta_data.unique_file_identifier); + } + // (2) a suffix, which length varies according to the module type, consisting of: + // + module type (1 byte) + // + row group ordinal (2 bytes, optional) + // + column ordinal (2 bytes, optional) + // + page ordinal (2 bytes, optional) + WriteSuffix(crypto_meta_data); +} + +void ParquetAdditionalAuthenticatedData::WritePrefix(const std::string &prefix) { + if (prefix.empty()) { + throw InvalidInputException("Prefix for Additional Authenticated Data is empty"); + } + WriteStringData(prefix); + additional_authenticated_data_prefix_size = additional_authenticated_data->GetPosition(); +} + +void ParquetAdditionalAuthenticatedData::WriteSuffix(const CryptoMetaData &crypto_meta_data) { + if (!additional_authenticated_data_prefix_size.IsValid()) { + throw InvalidInputException("Prefix for Parquet additional authenticated data is not set"); + } + + if (crypto_meta_data.module < 0) { + throw InvalidInputException("Parquet Crypto Module not initialized"); + } + WriteData(crypto_meta_data.module); + + if (crypto_meta_data.row_group_ordinal < 0) { + if (crypto_meta_data.module != ParquetCrypto::FOOTER) { + throw InvalidInputException("Parquet 
Encryption: Row group not initialized"); + } + // Footer + return; + } + WriteData(crypto_meta_data.row_group_ordinal); + + if (crypto_meta_data.column_ordinal < 0) { + return; + } + WriteData(crypto_meta_data.column_ordinal); + + if (crypto_meta_data.page_ordinal < 0) { + return; + } + WriteData(crypto_meta_data.page_ordinal); +} + ParquetEncryptionConfig::ParquetEncryptionConfig() { } @@ -171,12 +247,14 @@ class EncryptionTransport : public TTransport { //! Decryption wrapper for a transport protocol class DecryptionTransport : public TTransport { public: - DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) + DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p, + const CryptoMetaData &crypto_meta_data) : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState(EncryptionTypes::GCM, key.size())), read_buffer_size(0), read_buffer_offset(0) { - Initialize(key); + Initialize(key, crypto_meta_data); } + uint32_t read_virt(uint8_t *buf, uint32_t len) override { const uint32_t result = len; @@ -198,7 +276,6 @@ class DecryptionTransport : public TTransport { } uint32_t Finalize() { - if (read_buffer_offset != read_buffer_size) { throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" "read buffer offset: %d, read buffer size: %d", @@ -225,7 +302,7 @@ class DecryptionTransport : public TTransport { } private: - void Initialize(const string &key) { + void Initialize(const string &key, const CryptoMetaData &crypto_meta_data) { // Read encoded length (don't add to read_bytes) data_t length_buf[ParquetCrypto::LENGTH_BYTES]; trans.read(length_buf, ParquetCrypto::LENGTH_BYTES); @@ -234,8 +311,15 @@ class DecryptionTransport : public TTransport { // Read nonce and initialize AES transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES); // check whether context is initialized - aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast(key.data()), - key.size()); + if (!crypto_meta_data.IsEmpty()) { + aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast(key.data()), + key.size(), crypto_meta_data.additional_authenticated_data->data(), + crypto_meta_data.additional_authenticated_data->size()); + crypto_meta_data.additional_authenticated_data->Rewind(); + } else { + aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast(key.data()), + key.size()); + } } void ReadBlock(uint8_t *buf) { @@ -298,10 +382,10 @@ class SimpleReadTransport : public TTransport { }; uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key, - const EncryptionUtil &encryption_util_p) { + const EncryptionUtil &encryption_util_p, const CryptoMetaData &crypto_meta_data) { TCompactProtocolFactoryT tproto_factory; - auto dprot = - tproto_factory.getProtocol(duckdb_base_std::make_shared(iprot, key, encryption_util_p)); + auto dprot = tproto_factory.getProtocol( + duckdb_base_std::make_shared(iprot, key, encryption_util_p, crypto_meta_data)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // We have to read the whole thing otherwise thrift throws an error before we realize we're decryption is wrong @@ -332,11 +416,12 @@ uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const strin } uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, - const string &key, const EncryptionUtil 
&encryption_util_p) { + const string &key, const EncryptionUtil &encryption_util_p, + const CryptoMetaData &crypto_meta_data) { // Create decryption protocol TCompactProtocolFactoryT tproto_factory; - auto dprot = - tproto_factory.getProtocol(duckdb_base_std::make_shared(iprot, key, encryption_util_p)); + auto dprot = tproto_factory.getProtocol( + duckdb_base_std::make_shared(iprot, key, encryption_util_p, crypto_meta_data)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // Read buffer @@ -362,6 +447,73 @@ uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffe return etrans.Finalize(); } +int8_t ParquetCrypto::GetModuleHeader(const ColumnChunk &chunk, uint16_t page_ordinal) { + if (page_ordinal > 0) { + // always return data page header if ordinal > 0 + return DATA_PAGE_HEADER; + } + // There is at maximum 1 dictionary, index or bf filter page header per column chunk + if (chunk.meta_data.__isset.dictionary_page_offset) { + return DICTIONARY_PAGE_HEADER; + } else if (chunk.meta_data.__isset.index_page_offset) { + return OFFSET_INDEX; + } else if (chunk.meta_data.__isset.bloom_filter_offset) { + return ParquetCrypto::BLOOM_FILTER_HEADER; + } + + return DATA_PAGE_HEADER; +} + +int8_t ParquetCrypto::GetModule(const ColumnChunk &chunk, PageType::type page_type, uint16_t page_ordinal) { + if (chunk.meta_data.__isset.bloom_filter_offset && page_ordinal == 0) { + // return bitset if it is the first page ordinal + return ParquetCrypto::BLOOM_FILTER_BITSET; + } + + switch (page_type) { + case PageType::DATA_PAGE: + case PageType::DATA_PAGE_V2: + return DATA_PAGE; + case PageType::DICTIONARY_PAGE: + return DICTIONARY_PAGE; + case PageType::INDEX_PAGE: + if (chunk.meta_data.__isset.index_page_offset) { + return OFFSET_INDEX; + } + return COLUMN_INDEX; + default: + throw InvalidInputException("Module not found"); + } +} + +int16_t ParquetCrypto::GetFinalPageOrdinal(const ColumnChunk &chunk, uint8_t module, uint16_t page_ordinal) { + switch (module) { + case DATA_PAGE_HEADER: + if (chunk.meta_data.__isset.dictionary_page_offset) { + page_ordinal -= 1; + } else if (chunk.meta_data.__isset.index_page_offset) { + page_ordinal -= 1; + } else if (chunk.meta_data.__isset.bloom_filter_offset) { + page_ordinal -= 1; + } + return page_ordinal; + case DATA_PAGE: + return page_ordinal; + default: + // All modules except DataPage(Header) are -1 (absent) + return -1; + } +} + +void ParquetCrypto::GenerateAdditionalAuthenticatedData(Allocator &allocator, CryptoMetaData &aad_crypto_metadata) { + if (aad_crypto_metadata.IsEmpty()) { + // no aad, old duckdb-parquet crypto implementation + aad_crypto_metadata.ClearAdditionalAuthenticatedData(); + return; + } + aad_crypto_metadata.additional_authenticated_data->WriteParquetAAD(aad_crypto_metadata); +} + bool ParquetCrypto::ValidKey(const std::string &key) { switch (key.size()) { case 16: @@ -403,4 +555,34 @@ void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters &par } } +CryptoMetaData::CryptoMetaData(Allocator &allocator) { + additional_authenticated_data = make_uniq(allocator); +} + +void CryptoMetaData::Initialize(const std::string &unique_file_identifier_p, int16_t row_group_ordinal_p, + int16_t column_ordinal_p, int8_t module_p, int16_t page_ordinal_p) { + if (unique_file_identifier_p.empty()) { + // aad not used for encryption + // this happens with old duckdb-parquet encryption + return; + } + unique_file_identifier = unique_file_identifier_p; + module = module_p; + row_group_ordinal = row_group_ordinal_p; 
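The comments above spell out how the additional authenticated data is assembled: a prefix made of the optional aad-prefix plus the unique file identifier, followed by a module-dependent suffix of a 1-byte module type and optional 2-byte row group, column, and page ordinals, with the footer module carrying no ordinals at all. A minimal stand-alone sketch of that layout; BuildAAD and AppendLE16 are illustrative names, and the little-endian ordinal encoding is my reading of the Parquet modular-encryption spec rather than something shown in this diff:

#include <cstdint>
#include <string>
#include <vector>

static void AppendLE16(std::vector<uint8_t> &out, uint16_t v) {
	out.push_back(static_cast<uint8_t>(v & 0xFF));
	out.push_back(static_cast<uint8_t>(v >> 8));
}

std::vector<uint8_t> BuildAAD(const std::string &file_unique, uint8_t module_type, int32_t row_group,
                              int32_t column, int32_t page) {
	std::vector<uint8_t> aad(file_unique.begin(), file_unique.end()); // prefix: unique file identifier
	aad.push_back(module_type);                                       // suffix: 1-byte module type
	if (row_group < 0) {
		return aad; // footer module: no ordinals
	}
	AppendLE16(aad, static_cast<uint16_t>(row_group)); // 2-byte row group ordinal
	if (column < 0) {
		return aad;
	}
	AppendLE16(aad, static_cast<uint16_t>(column)); // 2-byte column ordinal
	if (page < 0) {
		return aad;
	}
	AppendLE16(aad, static_cast<uint16_t>(page)); // 2-byte page ordinal
	return aad;
}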
+ column_ordinal = column_ordinal_p; + page_ordinal = page_ordinal_p; +} + +void CryptoMetaData::SetModule(int8_t module_p) { + module = module_p; +} + +void CryptoMetaData::ClearAdditionalAuthenticatedData() { + additional_authenticated_data = nullptr; +} + +bool CryptoMetaData::IsEmpty() const { + return unique_file_identifier.empty(); +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 37e6cd0b7..59102a8eb 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -7,14 +7,16 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "parquet_crypto.hpp" #include "parquet_metadata.hpp" #include "parquet_reader.hpp" #include "parquet_writer.hpp" +#include "parquet_shredding.hpp" #include "reader/struct_column_reader.hpp" #include "zstd_file_system.hpp" #include "writer/primitive_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include #include @@ -43,6 +45,9 @@ #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" @@ -54,156 +59,6 @@ namespace duckdb { -static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { - case_insensitive_map_t name_to_type_map; - switch (type.id()) { - case LogicalTypeId::LIST: - name_to_type_map.emplace("element", ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - name_to_type_map.emplace("key", MapType::KeyType(type)); - name_to_type_map.emplace("value", MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - if (child_type.first == FieldID::DUCKDB_FIELD_ID) { - throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); - } - name_to_type_map.emplace(child_type); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNameToTypeMap"); - } // LCOV_EXCL_STOP - return name_to_type_map; -} - -static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, - vector &child_types) { - switch (type.id()) { - case LogicalTypeId::LIST: - child_names.emplace_back("element"); - child_types.emplace_back(ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - child_names.emplace_back("key"); - child_names.emplace_back("value"); - child_types.emplace_back(MapType::KeyType(type)); - child_types.emplace_back(MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - child_names.emplace_back(child_type.first); - child_types.emplace_back(child_type.second); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNamesAndTypes"); - } // LCOV_EXCL_STOP -} - -static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t 
&field_id, const vector &names, - const vector &sql_types) { - D_ASSERT(names.size() == sql_types.size()); - for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { - const auto &col_name = names[col_idx]; - auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); - D_ASSERT(inserted.second); - - const auto &col_type = sql_types[col_idx]; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - continue; - } - - // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first - vector child_names; - vector child_types; - GetChildNamesAndTypes(col_type, child_names, child_types); - - GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); - } -} - -static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, - unordered_set &unique_field_ids, - const case_insensitive_map_t &name_to_type_map) { - const auto &struct_type = field_ids_value.type(); - if (struct_type.id() != LogicalTypeId::STRUCT) { - throw BinderException( - "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", - FieldID::DUCKDB_FIELD_ID); - } - const auto &struct_children = StructValue::GetChildren(field_ids_value); - D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); - if (col_name == FieldID::DUCKDB_FIELD_ID) { - continue; - } - - auto it = name_to_type_map.find(col_name); - if (it == name_to_type_map.end()) { - string names; - for (const auto &name : name_to_type_map) { - if (!names.empty()) { - names += ", "; - } - names += name.first; - } - throw BinderException( - "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " - "column is a partition column. 
Available column names: [%s]", - col_name, names); - } - D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys - - const auto &child_value = struct_children[i]; - const auto &child_type = child_value.type(); - optional_ptr field_id_value; - optional_ptr child_field_ids_value; - - if (child_type.id() == LogicalTypeId::STRUCT) { - const auto &nested_children = StructValue::GetChildren(child_value); - D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); - for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { - const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); - if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { - field_id_value = &nested_children[nested_i]; - } else { - child_field_ids_value = &child_value; - } - } - } else { - field_id_value = &child_value; - } - - FieldID field_id; - if (field_id_value) { - Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); - const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); - if (!unique_field_ids.insert(field_id_int).second) { - throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); - } - field_id = FieldID(UnsafeNumericCast(field_id_int)); - } - auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); - D_ASSERT(inserted.second); - - if (child_field_ids_value) { - const auto &col_type = it->second; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", - col_name, LogicalTypeIdToString(col_type.id())); - } - - GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, - GetChildNameToTypeMap(col_type)); - } - } -} - struct ParquetWriteBindData : public TableFunctionData { vector sql_types; vector column_names; @@ -233,41 +88,33 @@ struct ParquetWriteBindData : public TableFunctionData { optional_idx row_groups_per_file; ChildFieldIDs field_ids; + ShreddingType shredding_types; //! The compression level, higher value is more int64_t compression_level = ZStdFileSystem::DefaultCompressionLevel(); //! Which encodings to include when writing ParquetVersion parquet_version = ParquetVersion::V1; -}; -struct ParquetWriteGlobalState : public GlobalFunctionData { - unique_ptr writer; - optional_ptr op; - - void LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason) { - if (!op) { - return; - } - DUCKDB_LOG(writer->GetContext(), PhysicalOperatorLogType, *op, "ParquetWriter", "FlushRowGroup", - {{"file", writer->GetFileName()}, - {"rows", to_string(buffer.Count())}, - {"size", to_string(buffer.SizeInBytes())}, - {"reason", reason}}); - } - - mutex lock; - unique_ptr combine_buffer; + //! 
Which geo-parquet version to use when writing + GeoParquetVersion geoparquet_version = GeoParquetVersion::V1; }; -struct ParquetWriteLocalState : public LocalFunctionData { - explicit ParquetWriteLocalState(ClientContext &context, const vector &types) : buffer(context, types) { - buffer.SetPartitionIndex(0); // Makes the buffer manager less likely to spill this data - buffer.InitializeAppend(append_state); +void ParquetWriteGlobalState::LogFlushingRowGroup(const ColumnDataCollection &buffer, const string &reason) { + if (!op) { + return; } + DUCKDB_LOG(writer->GetContext(), PhysicalOperatorLogType, *op, "ParquetWriter", "FlushRowGroup", + {{"file", writer->GetFileName()}, + {"rows", to_string(buffer.Count())}, + {"size", to_string(buffer.SizeInBytes())}, + {"reason", reason}}); +} - ColumnDataCollection buffer; - ColumnDataAppendState append_state; -}; +ParquetWriteLocalState::ParquetWriteLocalState(ClientContext &context, const vector &types) + : buffer(context, types) { + buffer.SetPartitionIndex(0); // Makes the buffer manager less likely to spill this data + buffer.InitializeAppend(append_state); +} static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &input) { auto &copy_options = input.options; @@ -291,6 +138,8 @@ static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &inp copy_options["binary_as_string"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["file_row_number"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["can_have_nan"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); + copy_options["geoparquet_version"] = CopyOption(LogicalType::VARCHAR, CopyOptionMode::WRITE_ONLY); + copy_options["shredding"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); } static unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBindInput &input, @@ -342,7 +191,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun if (option.second[0].type().id() == LogicalTypeId::VARCHAR && StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { idx_t field_id = 0; - GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types); + FieldID::GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types); } else { unordered_set unique_field_ids; case_insensitive_map_t name_to_type_map; @@ -353,7 +202,54 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } name_to_type_map.emplace(names[col_idx], sql_types[col_idx]); } - GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + FieldID::GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + } + } else if (loption == "shredding") { + if (option.second[0].type().id() == LogicalTypeId::VARCHAR && + StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { + throw NotImplementedException("The 'auto' option is not yet implemented for 'shredding'"); + } else { + case_insensitive_set_t variant_names; + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + if (sql_types[col_idx].id() != LogicalTypeId::VARIANT) { + continue; + } + variant_names.emplace(names[col_idx]); + } + auto &shredding_types_value = option.second[0]; + if (shredding_types_value.type().id() != LogicalTypeId::STRUCT) { + throw BinderException("SHREDDING value should be a STRUCT of column names to types, e.g.: {col1: " + "'INTEGER[]', col2: 'BOOLEAN'}"); + } + const auto &struct_type = shredding_types_value.type(); + const auto
&struct_children = StructValue::GetChildren(shredding_types_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + auto it = variant_names.find(col_name); + if (it == variant_names.end()) { + string names; + for (const auto &entry : variant_names) { + if (!names.empty()) { + names += ", "; + } + names += entry; + } + if (names.empty()) { + throw BinderException("VARIANT by name \"%s\" specified in SHREDDING not found. There are " + "no VARIANT columns present.", + col_name); + } else { + throw BinderException( + "VARIANT by name \"%s\" specified in SHREDDING not found. Consider using " + "WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available names of VARIANT columns: [%s]", + col_name, names); + } + } + const auto &child_value = struct_children[i]; + bind_data->shredding_types.AddChild(col_name, ShreddingType::GetShreddingTypes(child_value)); + } } } else if (loption == "kv_metadata") { auto &kv_struct = option.second[0]; @@ -426,6 +322,19 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } else { throw BinderException("Expected parquet_version 'V1' or 'V2'"); } + } else if (loption == "geoparquet_version") { + const auto roption = StringUtil::Upper(option.second[0].ToString()); + if (roption == "NONE") { + bind_data->geoparquet_version = GeoParquetVersion::NONE; + } else if (roption == "V1") { + bind_data->geoparquet_version = GeoParquetVersion::V1; + } else if (roption == "V2") { + bind_data->geoparquet_version = GeoParquetVersion::V2; + } else if (roption == "BOTH") { + bind_data->geoparquet_version = GeoParquetVersion::BOTH; + } else { + throw BinderException("Expected geoparquet_version 'NONE', 'V1', 'V2' or 'BOTH'"); + } } else { throw InternalException("Unrecognized option for PARQUET: %s", option.first.c_str()); } @@ -454,10 +363,11 @@ static unique_ptr ParquetWriteInitializeGlobal(ClientContext auto &fs = FileSystem::GetFileSystem(context); global_state->writer = make_uniq( context, fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, parquet_bind.codec, - parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, parquet_bind.encryption_config, - parquet_bind.dictionary_size_limit, parquet_bind.string_dictionary_page_size_limit, - parquet_bind.enable_bloom_filters, parquet_bind.bloom_filter_false_positive_ratio, - parquet_bind.compression_level, parquet_bind.debug_use_openssl, parquet_bind.parquet_version); + parquet_bind.field_ids.Copy(), parquet_bind.shredding_types.Copy(), parquet_bind.kv_metadata, + parquet_bind.encryption_config, parquet_bind.dictionary_size_limit, + parquet_bind.string_dictionary_page_size_limit, parquet_bind.enable_bloom_filters, + parquet_bind.bloom_filter_false_positive_ratio, parquet_bind.compression_level, parquet_bind.debug_use_openssl, + parquet_bind.parquet_version, parquet_bind.geoparquet_version); return std::move(global_state); } @@ -483,7 +393,7 @@ static void ParquetWriteSink(ExecutionContext &context, FunctionData &bind_data_ global_state.LogFlushingRowGroup(local_state.buffer, reason); // if the chunk collection exceeds a certain size (rows/bytes) we flush it to the parquet file local_state.append_state.current_chunk_state.handles.clear(); - global_state.writer->Flush(local_state.buffer); + global_state.writer->Flush(local_state.buffer, local_state.transform_data);
local_state.buffer.InitializeAppend(local_state.append_state); } } @@ -498,7 +408,7 @@ static void ParquetWriteCombine(ExecutionContext &context, FunctionData &bind_da local_state.buffer.SizeInBytes() >= bind_data.row_group_size_bytes / 2) { // local state buffer is more than half of the row_group_size(_bytes), just flush it global_state.LogFlushingRowGroup(local_state.buffer, "Combine"); - global_state.writer->Flush(local_state.buffer); + global_state.writer->Flush(local_state.buffer, local_state.transform_data); return; } @@ -513,7 +423,7 @@ static void ParquetWriteCombine(ExecutionContext &context, FunctionData &bind_da guard.unlock(); global_state.LogFlushingRowGroup(*owned_combine_buffer, "Combine"); // Lock free, of course - global_state.writer->Flush(*owned_combine_buffer); + global_state.writer->Flush(*owned_combine_buffer, local_state.transform_data); } return; } @@ -527,7 +437,7 @@ static void ParquetWriteFinalize(ClientContext &context, FunctionData &bind_data // flush the combine buffer (if it's there) if (global_state.combine_buffer) { global_state.LogFlushingRowGroup(*global_state.combine_buffer, "Finalize"); - global_state.writer->Flush(*global_state.combine_buffer); + global_state.writer->Flush(*global_state.combine_buffer, global_state.transform_data); } // finalize: write any additional metadata to the file here @@ -626,6 +536,39 @@ ParquetVersion EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template <> +const char *EnumUtil::ToChars(GeoParquetVersion value) { + switch (value) { + case GeoParquetVersion::NONE: + return "NONE"; + case GeoParquetVersion::V1: + return "V1"; + case GeoParquetVersion::V2: + return "V2"; + case GeoParquetVersion::BOTH: + return "BOTH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); + } +} + +template <> +GeoParquetVersion EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return GeoParquetVersion::NONE; + } + if (StringUtil::Equals(value, "V1")) { + return GeoParquetVersion::V1; + } + if (StringUtil::Equals(value, "V2")) { + return GeoParquetVersion::V2; + } + if (StringUtil::Equals(value, "BOTH")) { + return GeoParquetVersion::BOTH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + static optional_idx SerializeCompressionLevel(const int64_t compression_level) { return compression_level < 0 ? 
NumericLimits::Maximum() - NumericCast(AbsValue(compression_level)) : NumericCast(compression_level); @@ -679,6 +622,9 @@ static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bin serializer.WritePropertyWithDefault(115, "string_dictionary_page_size_limit", bind_data.string_dictionary_page_size_limit, default_value.string_dictionary_page_size_limit); + serializer.WritePropertyWithDefault(116, "geoparquet_version", bind_data.geoparquet_version, + default_value.geoparquet_version); + serializer.WriteProperty(117, "shredding_types", bind_data.shredding_types); } static unique_ptr ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) { @@ -711,6 +657,9 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize deserializer.ReadPropertyWithExplicitDefault(114, "parquet_version", default_value.parquet_version); data->string_dictionary_page_size_limit = deserializer.ReadPropertyWithExplicitDefault( 115, "string_dictionary_page_size_limit", default_value.string_dictionary_page_size_limit); + data->geoparquet_version = + deserializer.ReadPropertyWithExplicitDefault(116, "geoparquet_version", default_value.geoparquet_version); + data->shredding_types = deserializer.ReadProperty(117, "shredding_types"); return std::move(data); } @@ -747,7 +696,8 @@ static unique_ptr ParquetWritePrepareBatch(ClientContext &con unique_ptr collection) { auto &global_state = gstate.Cast(); auto result = make_uniq(); - global_state.writer->PrepareRowGroup(*collection, result->prepared_row_group); + unique_ptr transform_data; + global_state.writer->PrepareRowGroup(*collection, result->prepared_row_group, transform_data); return std::move(result); } @@ -828,8 +778,20 @@ static bool IsTypeLossy(const LogicalType &type) { return type.id() == LogicalTypeId::HUGEINT || type.id() == LogicalTypeId::UHUGEINT; } -static vector> ParquetWriteSelect(CopyToSelectInput &input) { +static bool IsExtensionGeometryType(const LogicalType &type, ClientContext &context) { + if (type.id() != LogicalTypeId::BLOB) { + return false; + } + if (!type.HasAlias()) { + return false; + } + if (type.GetAlias() != "GEOMETRY") { + return false; + } + return GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context); +} +static vector> ParquetWriteSelect(CopyToSelectInput &input) { auto &context = input.context; vector> result; @@ -837,19 +799,15 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu bool any_change = false; for (auto &expr : input.select_list) { - const auto &type = expr->return_type; const auto &name = expr->GetAlias(); // Spatial types need to be encoded into WKB when writing GeoParquet. 
// But dont perform this conversion if this is a EXPORT DATABASE statement - if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::BLOB && type.HasAlias() && - type.GetAlias() == "GEOMETRY" && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - - LogicalType wkb_blob_type(LogicalTypeId::BLOB); - wkb_blob_type.SetAlias("WKB_BLOB"); - - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), wkb_blob_type, false); + if (input.copy_to_type == CopyToType::COPY_TO_FILE && IsExtensionGeometryType(type, context)) { + // Cast the column to GEOMETRY + auto cast_expr = + BoundCastExpression::AddCastToType(context, std::move(expr), LogicalType::GEOMETRY(), false); cast_expr->SetAlias(name); result.push_back(std::move(cast_expr)); any_change = true; @@ -924,6 +882,13 @@ static void LoadInternal(ExtensionLoader &loader) { ParquetBloomProbeFunction bloom_probe_fun; loader.RegisterFunction(MultiFileReader::CreateFunctionSet(bloom_probe_fun)); + // parquet_full_metadata + ParquetFullMetadataFunction full_meta_fun; + loader.RegisterFunction(MultiFileReader::CreateFunctionSet(full_meta_fun)); + + // variant_to_parquet_variant + loader.RegisterFunction(VariantColumnWriter::GetTransformFunction()); + CopyFunction function("parquet"); function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; @@ -970,9 +935,6 @@ static void LoadInternal(ExtensionLoader &loader) { "enable_geoparquet_conversion", "Attempt to decode/encode geometry data in/as GeoParquet files if the spatial extension is present.", LogicalType::BOOLEAN, Value::BOOLEAN(true)); - config.AddExtensionOption("variant_legacy_encoding", - "Enables the Parquet reader to identify a Variant structurally.", LogicalType::BOOLEAN, - Value::BOOLEAN(false)); } void ParquetExtension::Load(ExtensionLoader &loader) { diff --git a/src/duckdb/extension/parquet/parquet_field_id.cpp b/src/duckdb/extension/parquet/parquet_field_id.cpp new file mode 100644 index 000000000..642fc26c7 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_field_id.cpp @@ -0,0 +1,180 @@ +#include "parquet_field_id.hpp" +#include "duckdb/common/exception/binder_exception.hpp" + +namespace duckdb { + +constexpr const char *FieldID::DUCKDB_FIELD_ID; + +ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { +} + +ChildFieldIDs ChildFieldIDs::Copy() const { + ChildFieldIDs result; + for (const auto &id : *ids) { + result.ids->emplace(id.first, id.second.Copy()); + } + return result; +} + +FieldID::FieldID() : set(false) { +} + +FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { +} + +FieldID FieldID::Copy() const { + auto result = set ? 
FieldID(field_id) : FieldID(); + result.child_field_ids = child_field_ids.Copy(); + return result; +} + +static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { + case_insensitive_map_t name_to_type_map; + switch (type.id()) { + case LogicalTypeId::LIST: + name_to_type_map.emplace("element", ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + name_to_type_map.emplace("key", MapType::KeyType(type)); + name_to_type_map.emplace("value", MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + if (child_type.first == FieldID::DUCKDB_FIELD_ID) { + throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); + } + name_to_type_map.emplace(child_type); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in GetChildNameToTypeMap"); + } // LCOV_EXCL_STOP + return name_to_type_map; +} + +static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, + vector &child_types) { + switch (type.id()) { + case LogicalTypeId::LIST: + child_names.emplace_back("element"); + child_types.emplace_back(ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + child_names.emplace_back("key"); + child_names.emplace_back("value"); + child_types.emplace_back(MapType::KeyType(type)); + child_types.emplace_back(MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + child_names.emplace_back(child_type.first); + child_types.emplace_back(child_type.second); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in GetChildNamesAndTypes"); + } // LCOV_EXCL_STOP +} + +void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types) { + D_ASSERT(names.size() == sql_types.size()); + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + const auto &col_name = names[col_idx]; + auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); + D_ASSERT(inserted.second); + + const auto &col_type = sql_types[col_idx]; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + continue; + } + + // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first + vector child_names; + vector child_types; + GetChildNamesAndTypes(col_type, child_names, child_types); + GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); + } +} + +void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map) { + const auto &struct_type = field_ids_value.type(); + if (struct_type.id() != LogicalTypeId::STRUCT) { + throw BinderException( + "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", + FieldID::DUCKDB_FIELD_ID); + } + const auto &struct_children = StructValue::GetChildren(field_ids_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + if (col_name == FieldID::DUCKDB_FIELD_ID) { + continue; + } + + auto it = name_to_type_map.find(col_name); + if (it == 
name_to_type_map.end()) { + string names; + for (const auto &name : name_to_type_map) { + if (!names.empty()) { + names += ", "; + } + names += name.first; + } + throw BinderException( + "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available column names: [%s]", + col_name, names); + } + D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys + + const auto &child_value = struct_children[i]; + const auto &child_type = child_value.type(); + optional_ptr field_id_value; + optional_ptr child_field_ids_value; + + if (child_type.id() == LogicalTypeId::STRUCT) { + const auto &nested_children = StructValue::GetChildren(child_value); + D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); + for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { + const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); + if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { + field_id_value = &nested_children[nested_i]; + } else { + child_field_ids_value = &child_value; + } + } + } else { + field_id_value = &child_value; + } + + FieldID field_id; + if (field_id_value) { + Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); + const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); + if (!unique_field_ids.insert(field_id_int).second) { + throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); + } + field_id = FieldID(UnsafeNumericCast(field_id_int)); + } + auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); + D_ASSERT(inserted.second); + + if (child_field_ids_value) { + const auto &col_type = it->second; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", + col_name, LogicalTypeIdToString(col_type.id())); + } + + GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, + GetChildNameToTypeMap(col_type)); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp b/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp index 09a69ce57..5d51c81e6 100644 --- a/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp +++ b/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp @@ -6,9 +6,11 @@ namespace duckdb { ParquetFileMetadataCache::ParquetFileMetadataCache(unique_ptr file_metadata, CachingFileHandle &handle, - unique_ptr geo_metadata, idx_t footer_size) - : metadata(std::move(file_metadata)), geo_metadata(std::move(geo_metadata)), footer_size(footer_size), - validate(handle.Validate()), last_modified(handle.GetLastModifiedTime()), version_tag(handle.GetVersionTag()) { + unique_ptr geo_metadata, + unique_ptr crypto_metadata, idx_t footer_size) + : metadata(std::move(file_metadata)), geo_metadata(std::move(geo_metadata)), + crypto_metadata(std::move(crypto_metadata)), footer_size(footer_size), validate(handle.Validate()), + last_modified(handle.GetLastModifiedTime()), version_tag(handle.GetVersionTag()) { } string ParquetFileMetadataCache::ObjectType() { diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/parquet_geometry.cpp similarity index 54% rename from 
src/duckdb/extension/parquet/geo_parquet.cpp rename to src/duckdb/extension/parquet/parquet_geometry.cpp index bddc36b43..7ab81cc2a 100644 --- a/src/duckdb/extension/parquet/geo_parquet.cpp +++ b/src/duckdb/extension/parquet/parquet_geometry.cpp @@ -1,193 +1,29 @@ -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "column_reader.hpp" #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/scalar_function.hpp" +#include "duckdb/function/scalar/geometry_functions.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/main/extension_helper.hpp" #include "reader/expression_column_reader.hpp" #include "parquet_reader.hpp" #include "yyjson.hpp" +#include "reader/string_column_reader.hpp" namespace duckdb { using namespace duckdb_yyjson; // NOLINT -//------------------------------------------------------------------------------ -// WKB stats -//------------------------------------------------------------------------------ -namespace { - -class BinaryReader { -public: - const char *beg; - const char *end; - const char *ptr; - - BinaryReader(const char *beg, uint32_t len) : beg(beg), end(beg + len), ptr(beg) { - } - - template - T Read() { - if (ptr + sizeof(T) > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - T val; - memcpy(&val, ptr, sizeof(T)); - ptr += sizeof(T); - return val; - } - - void Skip(idx_t len) { - if (ptr + len > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - ptr += len; - } - - const char *Reserve(idx_t len) { - if (ptr + len > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - auto ret = ptr; - ptr += len; - return ret; - } - - bool IsAtEnd() const { - return ptr >= end; - } -}; - -} // namespace - -static void UpdateBoundsFromVertexArray(GeometryExtent &bbox, uint32_t flag, const char *vert_array, - uint32_t vert_count) { - switch (flag) { - case 0: { // XY - constexpr auto vert_width = sizeof(double) * 2; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[2]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - } - } break; - case 1: { // XYZ - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - } - } break; - case 2: { // XYM - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendM(vert[2]); - } - } break; - case 3: { // XYZM - constexpr auto vert_width = sizeof(double) * 4; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[4]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - bbox.ExtendM(vert[3]); - } - } break; - default: - break; - } -} - -void GeometryStats::Update(const string_t &wkb) { - BinaryReader reader(wkb.GetData(), wkb.GetSize()); - - bool first_geom = true; - while (!reader.IsAtEnd()) { - reader.Read(); // byte order - auto type = 
reader.Read(); - auto kind = type % 1000; - auto flag = type / 1000; - const auto hasz = (flag & 0x01) != 0; - const auto hasm = (flag & 0x02) != 0; - - if (first_geom) { - // Only add the top-level geometry type - types.Add(type); - first_geom = false; - } - - const auto vert_width = sizeof(double) * (2 + (hasz ? 1 : 0) + (hasm ? 1 : 0)); - - switch (kind) { - case 1: { // POINT - - // Point are special in that they are considered "empty" if they are all-nan - const auto vert_array = reader.Reserve(vert_width); - const auto dims_count = 2 + (hasz ? 1 : 0) + (hasm ? 1 : 0); - double vert_point[4] = {0, 0, 0, 0}; - - memcpy(vert_point, vert_array, vert_width); - - for (auto dim_idx = 0; dim_idx < dims_count; dim_idx++) { - if (!std::isnan(vert_point[dim_idx])) { - bbox.ExtendX(vert_point[0]); - bbox.ExtendY(vert_point[1]); - if (hasz && hasm) { - bbox.ExtendZ(vert_point[2]); - bbox.ExtendM(vert_point[3]); - } else if (hasz) { - bbox.ExtendZ(vert_point[2]); - } else if (hasm) { - bbox.ExtendM(vert_point[2]); - } - break; - } - } - } break; - case 2: { // LINESTRING - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } break; - case 3: { // POLYGON - const auto ring_count = reader.Read(); - for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } - } break; - case 4: // MULTIPOINT - case 5: // MULTILINESTRING - case 6: // MULTIPOLYGON - case 7: { // GEOMETRYCOLLECTION - reader.Skip(sizeof(uint32_t)); - } break; - } - } -} - //------------------------------------------------------------------------------ // GeoParquetFileMetadata //------------------------------------------------------------------------------ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_parquet::FileMetaData &file_meta_data, const ClientContext &context) { - // Conversion not enabled, or spatial is not loaded! 
if (!IsGeoParquetConversionEnabled(context)) { return nullptr; @@ -208,17 +44,19 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ throw InvalidInputException("Geoparquet metadata is not an object"); } - auto result = make_uniq(); + // We don't actually care about the version for now, as we only support V1+native + auto result = make_uniq(GeoParquetVersion::BOTH); // Check and parse the version const auto version_val = yyjson_obj_get(root, "version"); if (!yyjson_is_str(version_val)) { throw InvalidInputException("Geoparquet metadata does not have a version"); } - result->version = yyjson_get_str(version_val); - if (StringUtil::StartsWith(result->version, "2")) { - // Guard against a breaking future 2.0 version - throw InvalidInputException("Geoparquet version %s is not supported", result->version); + + auto version = yyjson_get_str(version_val); + if (StringUtil::StartsWith(version, "3")) { + // Guard against a breaking future 3.0 version + throw InvalidInputException("Geoparquet version %s is not supported", version); } // Check and parse the geometry columns @@ -292,8 +130,7 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ } void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const LogicalType &type, - const GeometryStats &stats) { - + const GeometryStatsData &stats) { // Lock the metadata lock_guard glock(write_lock); @@ -301,21 +138,18 @@ void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const if (it == geometry_columns.end()) { auto &column = geometry_columns[column_name]; - column.stats.types.Combine(stats.types); - column.stats.bbox.Combine(stats.bbox); + column.stats.Merge(stats); column.insertion_index = geometry_columns.size() - 1; } else { - it->second.stats.types.Combine(stats.types); - it->second.stats.bbox.Combine(stats.bbox); + it->second.stats.Merge(stats); } } void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) { - // GeoParquet does not support M or ZM coordinates. So remove any columns that have them.
unordered_set invalid_columns; for (auto &column : geometry_columns) { - if (column.second.stats.bbox.HasM()) { + if (column.second.stats.extent.HasM()) { invalid_columns.insert(column.first); } } @@ -344,7 +178,20 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) yyjson_mut_doc_set_root(doc, root); // Add the version - yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); + switch (version) { + case GeoParquetVersion::V1: + case GeoParquetVersion::BOTH: + yyjson_mut_obj_add_strcpy(doc, root, "version", "1.0.0"); + break; + case GeoParquetVersion::V2: + yyjson_mut_obj_add_strcpy(doc, root, "version", "2.0.0"); + break; + case GeoParquetVersion::NONE: + default: + // Should never happen, we should not be writing anything + yyjson_mut_doc_free(doc); + throw InternalException("GeoParquetVersion::NONE should not write metadata"); + } // Add the primary column yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), @@ -354,32 +201,31 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) const auto json_columns = yyjson_mut_obj_add_obj(doc, root, "columns"); for (auto &column : geometry_columns) { - const auto column_json = yyjson_mut_obj_add_obj(doc, json_columns, column.first.c_str()); yyjson_mut_obj_add_str(doc, column_json, "encoding", "WKB"); const auto geometry_types = yyjson_mut_obj_add_arr(doc, column_json, "geometry_types"); + for (auto &type_name : column.second.stats.types.ToString(false)) { yyjson_mut_arr_add_strcpy(doc, geometry_types, type_name.c_str()); } - const auto &bbox = column.second.stats.bbox; - - if (bbox.IsSet()) { + const auto &bbox = column.second.stats.extent; + if (bbox.HasXY()) { const auto bbox_arr = yyjson_mut_obj_add_arr(doc, column_json, "bbox"); - if (!column.second.stats.bbox.HasZ()) { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); + if (!column.second.stats.extent.HasZ()) { + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); } else { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmax); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_max); } } @@ -432,52 +278,31 @@ bool GeoParquetFileMetadata::IsGeoParquetConversionEnabled(const ClientContext & // Disabled by setting return false; } - if (!context.db->ExtensionIsLoaded("spatial")) { - // Spatial extension is not loaded, we cant convert anyway - return false; - } return true; } -LogicalType GeoParquetFileMetadata::GeometryType() { - auto blob_type = LogicalType(LogicalTypeId::BLOB); - blob_type.SetAlias("GEOMETRY"); - return blob_type; -} - const unordered_map &GeoParquetFileMetadata::GetColumnMeta() const { return geometry_columns; } 
-unique_ptr GeoParquetFileMetadata::CreateColumnReader(ParquetReader &reader, - const ParquetColumnSchema &schema, - ClientContext &context) { - - // Get the catalog - auto &catalog = Catalog::GetSystemCatalog(context); +unique_ptr GeometryColumnReader::Create(ParquetReader &reader, const ParquetColumnSchema &schema, + ClientContext &context) { + D_ASSERT(schema.type.id() == LogicalTypeId::GEOMETRY); + D_ASSERT(schema.children.size() == 1 && schema.children[0].type.id() == LogicalTypeId::BLOB); - // WKB encoding - if (schema.children[0].type.id() == LogicalTypeId::BLOB) { - // Look for a conversion function in the catalog - auto &conversion_func_set = - catalog.GetEntry(context, DEFAULT_SCHEMA, "st_geomfromwkb"); - auto conversion_func = conversion_func_set.functions.GetFunctionByArguments(context, {LogicalType::BLOB}); + // Make a string reader for the underlying WKB data + auto string_reader = make_uniq(reader, schema.children[0]); - // Create a bound function call expression - auto args = vector>(); - args.push_back(std::move(make_uniq(LogicalType::BLOB, 0))); - auto expr = - make_uniq(conversion_func.return_type, conversion_func, std::move(args), nullptr); - - // Create a child reader - auto child_reader = ColumnReader::CreateReader(reader, schema.children[0]); - - // Create an expression reader that applies the conversion function to the child reader - return make_uniq(context, std::move(child_reader), std::move(expr), schema); - } + // Wrap the string reader in a geometry reader + auto args = vector>(); + auto ref = make_uniq_base(LogicalTypeId::BLOB, 0); + args.push_back(std::move(ref)); - // Otherwise, unrecognized encoding - throw NotImplementedException("Unsupported geometry encoding"); + // TODO: Pass the actual target type here so we get the CRS information too + auto func = StGeomfromwkbFun::GetFunction(); + func.name = "ST_GeomFromWKB"; + auto expr = make_uniq_base(schema.type, func, std::move(args), nullptr); + return make_uniq(context, std::move(string_reader), std::move(expr), schema); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 2f34efae2..9fe14688f 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -46,23 +46,23 @@ enum class ParquetMetadataOperatorType : uint8_t { SCHEMA, KEY_VALUE_META_DATA, FILE_META_DATA, - BLOOM_PROBE + BLOOM_PROBE, + FULL_METADATA }; class ParquetMetadataFileProcessor { public: ParquetMetadataFileProcessor() = default; virtual ~ParquetMetadataFileProcessor() = default; - void Initialize(ClientContext &context, OpenFileInfo &file_info) { - ParquetOptions parquet_options(context); - reader = make_uniq(context, file_info, parquet_options); + void Initialize(ClientContext &context, ParquetReader &reader) { + InitializeInternal(context, reader); + } + virtual void InitializeInternal(ClientContext &context, ParquetReader &reader) {}; + virtual idx_t TotalRowCount(ParquetReader &reader) = 0; + virtual void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) = 0; + virtual bool ForceFlush() { + return false; } - virtual void InitializeInternal(ClientContext &context) {}; - virtual idx_t TotalRowCount() = 0; - virtual void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) = 0; - -protected: - unique_ptr reader; }; struct ParquetMetaDataBindData; @@ -115,10 +115,20 @@ struct ParquetMetadataGlobalState : public GlobalTableFunctionState { }; struct 
ParquetMetadataLocalState : public LocalTableFunctionState { + unique_ptr reader; unique_ptr processor; bool file_exhausted = true; idx_t row_idx = 0; idx_t total_rows = 0; + + void Initialize(ClientContext &context, OpenFileInfo &file_info) { + ParquetOptions parquet_options(context); + reader = make_uniq(context, file_info, parquet_options); + processor->Initialize(context, *reader); + total_rows = processor->TotalRowCount(*reader); + row_idx = 0; + file_exhausted = false; + } }; template @@ -179,9 +189,9 @@ static Value ParquetElementBoolean(bool value, bool is_iset) { class ParquetRowGroupMetadataProcessor : public ParquetMetadataFileProcessor { public: - void InitializeInternal(ClientContext &context) override; - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; private: vector column_schemas; @@ -334,18 +344,27 @@ static Value ConvertParquetGeoStatsTypes(const duckdb_parquet::GeospatialStatist vector types; types.reserve(stats.geospatial_types.size()); - GeometryKindSet kind_set; + GeometryTypeSet type_set = GeometryTypeSet::Empty(); for (auto &type : stats.geospatial_types) { - kind_set.Add(type); + const auto geom_type = (type % 1000); + const auto vert_type = (type / 1000); + if (geom_type < 1 || geom_type > 7) { + throw InvalidInputException("Unsupported geometry type in Parquet geo metadata"); + } + if (vert_type < 0 || vert_type > 3) { + throw InvalidInputException("Unsupported geometry vertex type in Parquet geo metadata"); + } + type_set.Add(static_cast(geom_type), static_cast(vert_type)); } - for (auto &type_name : kind_set.ToString(true)) { + + for (auto &type_name : type_set.ToString(true)) { types.push_back(Value(type_name)); } return Value::LIST(LogicalType::VARCHAR, types); } -void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context) { - auto meta_data = reader->GetFileMetadata(); +void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); column_schemas.clear(); for (idx_t schema_idx = 0; schema_idx < meta_data->schema.size(); schema_idx++) { auto &schema_element = meta_data->schema[schema_idx]; @@ -353,18 +372,19 @@ void ParquetRowGroupMetadataProcessor::InitializeInternal(ClientContext &context continue; } ParquetColumnSchema column_schema; - column_schema.type = reader->DeriveLogicalType(schema_element, column_schema); + column_schema.type = reader.DeriveLogicalType(schema_element, column_schema); column_schemas.push_back(std::move(column_schema)); } } -idx_t ParquetRowGroupMetadataProcessor::TotalRowCount() { - auto meta_data = reader->GetFileMetadata(); +idx_t ParquetRowGroupMetadataProcessor::TotalRowCount(ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); return meta_data->row_groups.size() * column_schemas.size(); } -void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetRowGroupMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); idx_t col_idx = row_idx % column_schemas.size(); idx_t row_group_idx = row_idx / column_schemas.size(); @@ 
-377,86 +397,90 @@ void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i auto &column_type = column_schema.type; // file_name - output.SetValue(0, output_idx, reader->file.path); + output[0].get().SetValue(output_idx, reader.file.path); // row_group_id - output.SetValue(1, output_idx, Value::BIGINT(UnsafeNumericCast(row_group_idx))); + output[1].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_group_idx))); // row_group_num_rows - output.SetValue(2, output_idx, Value::BIGINT(row_group.num_rows)); + output[2].get().SetValue(output_idx, Value::BIGINT(row_group.num_rows)); // row_group_num_columns - output.SetValue(3, output_idx, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); + output[3].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); // row_group_bytes - output.SetValue(4, output_idx, Value::BIGINT(row_group.total_byte_size)); + output[4].get().SetValue(output_idx, Value::BIGINT(row_group.total_byte_size)); // column_id - output.SetValue(5, output_idx, Value::BIGINT(UnsafeNumericCast(col_idx))); + output[5].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(col_idx))); // file_offset - output.SetValue(6, output_idx, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); + output[6].get().SetValue(output_idx, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); // num_values - output.SetValue(7, output_idx, Value::BIGINT(col_meta.num_values)); + output[7].get().SetValue(output_idx, Value::BIGINT(col_meta.num_values)); // path_in_schema - output.SetValue(8, output_idx, StringUtil::Join(col_meta.path_in_schema, ", ")); + output[8].get().SetValue(output_idx, StringUtil::Join(col_meta.path_in_schema, ", ")); // type - output.SetValue(9, output_idx, ConvertParquetElementToString(col_meta.type)); + output[9].get().SetValue(output_idx, ConvertParquetElementToString(col_meta.type)); // stats_min - output.SetValue(10, output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.min, stats.min)); + output[10].get().SetValue(output_idx, + ConvertParquetStats(column_type, column_schema, stats.__isset.min, stats.min)); // stats_max - output.SetValue(11, output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.max, stats.max)); + output[11].get().SetValue(output_idx, + ConvertParquetStats(column_type, column_schema, stats.__isset.max, stats.max)); // stats_null_count - output.SetValue(12, output_idx, ParquetElementBigint(stats.null_count, stats.__isset.null_count)); + output[12].get().SetValue(output_idx, ParquetElementBigint(stats.null_count, stats.__isset.null_count)); // stats_distinct_count - output.SetValue(13, output_idx, ParquetElementBigint(stats.distinct_count, stats.__isset.distinct_count)); + output[13].get().SetValue(output_idx, ParquetElementBigint(stats.distinct_count, stats.__isset.distinct_count)); // stats_min_value - output.SetValue(14, output_idx, - ConvertParquetStats(column_type, column_schema, stats.__isset.min_value, stats.min_value)); + output[14].get().SetValue( + output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.min_value, stats.min_value)); // stats_max_value - output.SetValue(15, output_idx, - ConvertParquetStats(column_type, column_schema, stats.__isset.max_value, stats.max_value)); + output[15].get().SetValue( + output_idx, ConvertParquetStats(column_type, column_schema, stats.__isset.max_value, stats.max_value)); // compression - output.SetValue(16, output_idx, 
ConvertParquetElementToString(col_meta.codec)); + output[16].get().SetValue(output_idx, ConvertParquetElementToString(col_meta.codec)); // encodings vector encoding_string; encoding_string.reserve(col_meta.encodings.size()); for (auto &encoding : col_meta.encodings) { encoding_string.push_back(ConvertParquetElementToString(encoding)); } - output.SetValue(17, output_idx, Value(StringUtil::Join(encoding_string, ", "))); + output[17].get().SetValue(output_idx, Value(StringUtil::Join(encoding_string, ", "))); // index_page_offset - output.SetValue(18, output_idx, - ParquetElementBigint(col_meta.index_page_offset, col_meta.__isset.index_page_offset)); + output[18].get().SetValue(output_idx, + ParquetElementBigint(col_meta.index_page_offset, col_meta.__isset.index_page_offset)); // dictionary_page_offset - output.SetValue(19, output_idx, - ParquetElementBigint(col_meta.dictionary_page_offset, col_meta.__isset.dictionary_page_offset)); + output[19].get().SetValue( + output_idx, ParquetElementBigint(col_meta.dictionary_page_offset, col_meta.__isset.dictionary_page_offset)); // data_page_offset - output.SetValue(20, output_idx, Value::BIGINT(col_meta.data_page_offset)); + output[20].get().SetValue(output_idx, Value::BIGINT(col_meta.data_page_offset)); // total_compressed_size - output.SetValue(21, output_idx, Value::BIGINT(col_meta.total_compressed_size)); + output[21].get().SetValue(output_idx, Value::BIGINT(col_meta.total_compressed_size)); // total_uncompressed_size - output.SetValue(22, output_idx, Value::BIGINT(col_meta.total_uncompressed_size)); + output[22].get().SetValue(output_idx, Value::BIGINT(col_meta.total_uncompressed_size)); // key_value_metadata vector map_keys, map_values; for (auto &entry : col_meta.key_value_metadata) { map_keys.push_back(Value::BLOB_RAW(entry.key)); map_values.push_back(Value::BLOB_RAW(entry.value)); } - output.SetValue(23, output_idx, - Value::MAP(LogicalType::BLOB, LogicalType::BLOB, std::move(map_keys), std::move(map_values))); + output[23].get().SetValue( + output_idx, Value::MAP(LogicalType::BLOB, LogicalType::BLOB, std::move(map_keys), std::move(map_values))); // bloom_filter_offset - output.SetValue(24, output_idx, - ParquetElementBigint(col_meta.bloom_filter_offset, col_meta.__isset.bloom_filter_offset)); + output[24].get().SetValue(output_idx, + ParquetElementBigint(col_meta.bloom_filter_offset, col_meta.__isset.bloom_filter_offset)); // bloom_filter_length - output.SetValue(25, output_idx, - ParquetElementBigint(col_meta.bloom_filter_length, col_meta.__isset.bloom_filter_length)); + output[25].get().SetValue(output_idx, + ParquetElementBigint(col_meta.bloom_filter_length, col_meta.__isset.bloom_filter_length)); // min_is_exact - output.SetValue(26, output_idx, ParquetElementBoolean(stats.is_min_value_exact, stats.__isset.is_min_value_exact)); + output[26].get().SetValue(output_idx, + ParquetElementBoolean(stats.is_min_value_exact, stats.__isset.is_min_value_exact)); // max_is_exact - output.SetValue(27, output_idx, ParquetElementBoolean(stats.is_max_value_exact, stats.__isset.is_max_value_exact)); + output[27].get().SetValue(output_idx, + ParquetElementBoolean(stats.is_max_value_exact, stats.__isset.is_max_value_exact)); // row_group_compressed_bytes - output.SetValue(28, output_idx, - ParquetElementBigint(row_group.total_compressed_size, row_group.__isset.total_compressed_size)); + output[28].get().SetValue( + output_idx, ParquetElementBigint(row_group.total_compressed_size, row_group.__isset.total_compressed_size)); // geo_stats_bbox, 
LogicalType::STRUCT(...) - output.SetValue(29, output_idx, ConvertParquetGeoStatsBBOX(col_meta.geospatial_statistics)); + output[29].get().SetValue(output_idx, ConvertParquetGeoStatsBBOX(col_meta.geospatial_statistics)); // geo_stats_types, LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(30, output_idx, ConvertParquetGeoStatsTypes(col_meta.geospatial_statistics)); + output[30].get().SetValue(output_idx, ConvertParquetGeoStatsTypes(col_meta.geospatial_statistics)); } //===--------------------------------------------------------------------===// @@ -465,8 +489,8 @@ void ParquetRowGroupMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i class ParquetSchemaProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -567,45 +591,46 @@ static Value ParquetLogicalTypeToString(const duckdb_parquet::LogicalType &type, return Value(); } -idx_t ParquetSchemaProcessor::TotalRowCount() { - return reader->GetFileMetadata()->schema.size(); +idx_t ParquetSchemaProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->schema.size(); } -void ParquetSchemaProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetSchemaProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); const auto &column = meta_data->schema[row_idx]; // file_name - output.SetValue(0, output_idx, reader->file.path); + output[0].get().SetValue(output_idx, reader.file.path); // name - output.SetValue(1, output_idx, column.name); + output[1].get().SetValue(output_idx, column.name); // type - output.SetValue(2, output_idx, ParquetElementString(column.type, column.__isset.type)); + output[2].get().SetValue(output_idx, ParquetElementString(column.type, column.__isset.type)); // type_length - output.SetValue(3, output_idx, ParquetElementInteger(column.type_length, column.__isset.type_length)); + output[3].get().SetValue(output_idx, ParquetElementInteger(column.type_length, column.__isset.type_length)); // repetition_type - output.SetValue(4, output_idx, ParquetElementString(column.repetition_type, column.__isset.repetition_type)); + output[4].get().SetValue(output_idx, ParquetElementString(column.repetition_type, column.__isset.repetition_type)); // num_children - output.SetValue(5, output_idx, ParquetElementBigint(column.num_children, column.__isset.num_children)); + output[5].get().SetValue(output_idx, ParquetElementBigint(column.num_children, column.__isset.num_children)); // converted_type - output.SetValue(6, output_idx, ParquetElementString(column.converted_type, column.__isset.converted_type)); + output[6].get().SetValue(output_idx, ParquetElementString(column.converted_type, column.__isset.converted_type)); // scale - output.SetValue(7, output_idx, ParquetElementBigint(column.scale, column.__isset.scale)); + output[7].get().SetValue(output_idx, ParquetElementBigint(column.scale, column.__isset.scale)); // precision - output.SetValue(8, output_idx, ParquetElementBigint(column.precision, column.__isset.precision)); + output[8].get().SetValue(output_idx, ParquetElementBigint(column.precision, column.__isset.precision)); // field_id - output.SetValue(9, output_idx, 
ParquetElementBigint(column.field_id, column.__isset.field_id)); + output[9].get().SetValue(output_idx, ParquetElementBigint(column.field_id, column.__isset.field_id)); // logical_type - output.SetValue(10, output_idx, ParquetLogicalTypeToString(column.logicalType, column.__isset.logicalType)); + output[10].get().SetValue(output_idx, ParquetLogicalTypeToString(column.logicalType, column.__isset.logicalType)); // duckdb_type ParquetColumnSchema column_schema; Value duckdb_type; if (column.__isset.type) { - duckdb_type = reader->DeriveLogicalType(column, column_schema).ToString(); + duckdb_type = reader.DeriveLogicalType(column, column_schema).ToString(); } - output.SetValue(11, output_idx, duckdb_type); + output[11].get().SetValue(output_idx, duckdb_type); // column_id - output.SetValue(12, output_idx, Value::BIGINT(UnsafeNumericCast(row_idx))); + output[12].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(row_idx))); } //===--------------------------------------------------------------------===// @@ -614,8 +639,8 @@ void ParquetSchemaProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t class ParquetKeyValueMetadataProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -631,17 +656,18 @@ void ParquetMetaDataOperator::BindSchemaGetFileMetadata()->key_value_metadata.size(); +idx_t ParquetKeyValueMetadataProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->key_value_metadata.size(); } -void ParquetKeyValueMetadataProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetKeyValueMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); auto &entry = meta_data->key_value_metadata[row_idx]; - output.SetValue(0, output_idx, Value(reader->file.path)); - output.SetValue(1, output_idx, Value::BLOB_RAW(entry.key)); - output.SetValue(2, output_idx, Value::BLOB_RAW(entry.value)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); + output[1].get().SetValue(output_idx, Value::BLOB_RAW(entry.key)); + output[2].get().SetValue(output_idx, Value::BLOB_RAW(entry.value)); } //===--------------------------------------------------------------------===// @@ -650,8 +676,8 @@ void ParquetKeyValueMetadataProcessor::ReadRow(DataChunk &output, idx_t output_i class ParquetFileMetadataProcessor : public ParquetMetadataFileProcessor { public: - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; }; template <> @@ -685,34 +711,34 @@ void ParquetMetaDataOperator::BindSchemaGetFileMetadata(); +void ParquetFileMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); // file_name - output.SetValue(0, output_idx, Value(reader->file.path)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); // created_by - output.SetValue(1, output_idx, ParquetElementStringVal(meta_data->created_by, 
meta_data->__isset.created_by)); + output[1].get().SetValue(output_idx, ParquetElementStringVal(meta_data->created_by, meta_data->__isset.created_by)); // num_rows - output.SetValue(2, output_idx, Value::BIGINT(meta_data->num_rows)); + output[2].get().SetValue(output_idx, Value::BIGINT(meta_data->num_rows)); // num_row_groups - output.SetValue(3, output_idx, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); + output[3].get().SetValue(output_idx, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); // format_version - output.SetValue(4, output_idx, Value::BIGINT(meta_data->version)); + output[4].get().SetValue(output_idx, Value::BIGINT(meta_data->version)); // encryption_algorithm - output.SetValue(5, output_idx, - ParquetElementString(meta_data->encryption_algorithm, meta_data->__isset.encryption_algorithm)); + output[5].get().SetValue( + output_idx, ParquetElementString(meta_data->encryption_algorithm, meta_data->__isset.encryption_algorithm)); // footer_signing_key_metadata - output.SetValue(6, output_idx, - ParquetElementStringVal(meta_data->footer_signing_key_metadata, - meta_data->__isset.footer_signing_key_metadata)); + output[6].get().SetValue(output_idx, ParquetElementStringVal(meta_data->footer_signing_key_metadata, + meta_data->__isset.footer_signing_key_metadata)); // file_size_bytes - output.SetValue(7, output_idx, Value::UBIGINT(reader->GetHandle().GetFileSize())); + output[7].get().SetValue(output_idx, Value::UBIGINT(reader.GetHandle().GetFileSize())); // footer_size - output.SetValue(8, output_idx, Value::UBIGINT(reader->metadata->footer_size)); + output[8].get().SetValue(output_idx, Value::UBIGINT(reader.metadata->footer_size)); } //===--------------------------------------------------------------------===// @@ -723,9 +749,9 @@ class ParquetBloomProbeProcessor : public ParquetMetadataFileProcessor { public: ParquetBloomProbeProcessor(const string &probe_column, const Value &probe_value); - void InitializeInternal(ClientContext &context) override; - idx_t TotalRowCount() override; - void ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) override; + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; private: string probe_column_name; @@ -754,34 +780,35 @@ ParquetBloomProbeProcessor::ParquetBloomProbeProcessor(const string &probe_colum : probe_column_name(probe_column), probe_constant(probe_value) { } -void ParquetBloomProbeProcessor::InitializeInternal(ClientContext &context) { +void ParquetBloomProbeProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { probe_column_idx = optional_idx::Invalid(); - for (idx_t column_idx = 0; column_idx < reader->columns.size(); column_idx++) { - if (reader->columns[column_idx].name == probe_column_name) { + for (idx_t column_idx = 0; column_idx < reader.columns.size(); column_idx++) { + if (reader.columns[column_idx].name == probe_column_name) { probe_column_idx = column_idx; break; } } if (!probe_column_idx.IsValid()) { - throw InvalidInputException("Column %s not found in %s", probe_column_name, reader->file.path); + throw InvalidInputException("Column %s not found in %s", probe_column_name, reader.file.path); } - auto transport = duckdb_base_std::make_shared(reader->GetHandle(), false); + auto transport = duckdb_base_std::make_shared(reader.GetHandle(), false); protocol = 
make_uniq>(std::move(transport)); allocator = &BufferAllocator::Get(context); filter = make_uniq( ExpressionType::COMPARE_EQUAL, - probe_constant.CastAs(context, reader->GetColumns()[probe_column_idx.GetIndex()].type)); + probe_constant.CastAs(context, reader.GetColumns()[probe_column_idx.GetIndex()].type)); } -idx_t ParquetBloomProbeProcessor::TotalRowCount() { - return reader->GetFileMetadata()->row_groups.size(); +idx_t ParquetBloomProbeProcessor::TotalRowCount(ParquetReader &reader) { + return reader.GetFileMetadata()->row_groups.size(); } -void ParquetBloomProbeProcessor::ReadRow(DataChunk &output, idx_t output_idx, idx_t row_idx) { - auto meta_data = reader->GetFileMetadata(); +void ParquetBloomProbeProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + auto meta_data = reader.GetFileMetadata(); auto &row_group = meta_data->row_groups[row_idx]; auto &column = row_group.columns[probe_column_idx.GetIndex()]; @@ -789,9 +816,124 @@ void ParquetBloomProbeProcessor::ReadRow(DataChunk &output, idx_t output_idx, id auto bloom_excludes = ParquetStatisticsUtils::BloomFilterExcludes(*filter, column.meta_data, *protocol, *allocator); - output.SetValue(0, output_idx, Value(reader->file.path)); - output.SetValue(1, output_idx, Value::BIGINT(NumericCast(row_idx))); - output.SetValue(2, output_idx, Value::BOOLEAN(bloom_excludes)); + output[0].get().SetValue(output_idx, Value(reader.file.path)); + output[1].get().SetValue(output_idx, Value::BIGINT(NumericCast(row_idx))); + output[2].get().SetValue(output_idx, Value::BOOLEAN(bloom_excludes)); +} + +//===--------------------------------------------------------------------===// +// Full Metadata +//===--------------------------------------------------------------------===// + +class FullMetadataProcessor : public ParquetMetadataFileProcessor { +public: + FullMetadataProcessor() = default; + + void InitializeInternal(ClientContext &context, ParquetReader &reader) override; + idx_t TotalRowCount(ParquetReader &reader) override; + void ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, ParquetReader &reader) override; + bool ForceFlush() override { + return true; + } + +private: + void PopulateMetadata(ParquetMetadataFileProcessor &processor, Vector &output, idx_t output_idx, + ParquetReader &reader); + + ParquetFileMetadataProcessor file_processor; + ParquetRowGroupMetadataProcessor row_group_processor; + ParquetSchemaProcessor schema_processor; + ParquetKeyValueMetadataProcessor kv_processor; +}; + +void FullMetadataProcessor::PopulateMetadata(ParquetMetadataFileProcessor &processor, Vector &output, idx_t output_idx, + ParquetReader &reader) { + auto count = processor.TotalRowCount(reader); + auto *result_data = FlatVector::GetData(output); + auto &result_struct = ListVector::GetEntry(output); + auto &result_struct_entries = StructVector::GetEntries(result_struct); + + ListVector::SetListSize(output, count); + ListVector::Reserve(output, count); + + result_data[output_idx].offset = 0; + result_data[output_idx].length = count; + + FlatVector::Validity(output).SetValid(output_idx); + + vector> vectors; + for (auto &entry : result_struct_entries) { + vectors.push_back(std::ref(*entry.get())); + entry->SetVectorType(VectorType::FLAT_VECTOR); + auto &validity = FlatVector::Validity(*entry); + validity.Initialize(count); + } + for (idx_t i = 0; i < count; i++) { + processor.ReadRow(vectors, i, i, reader); + } +} + +template <> +void ParquetMetaDataOperator::BindSchema(vector &return_types, + vector &names) { + 
names.emplace_back("parquet_file_metadata"); + vector file_meta_types; + vector file_meta_names; + ParquetMetaDataOperator::BindSchema(file_meta_types, file_meta_names); + child_list_t file_meta_children; + for (idx_t i = 0; i < file_meta_types.size(); i++) { + file_meta_children.push_back(make_pair(file_meta_names[i], file_meta_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(file_meta_children)))); + + names.emplace_back("parquet_metadata"); + vector row_group_types; + vector row_group_names; + ParquetMetaDataOperator::BindSchema(row_group_types, row_group_names); + child_list_t row_group_children; + for (idx_t i = 0; i < row_group_types.size(); i++) { + row_group_children.push_back(make_pair(row_group_names[i], row_group_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(row_group_children)))); + + names.emplace_back("parquet_schema"); + vector schema_types; + vector schema_names; + ParquetMetaDataOperator::BindSchema(schema_types, schema_names); + child_list_t schema_children; + for (idx_t i = 0; i < schema_types.size(); i++) { + schema_children.push_back(make_pair(schema_names[i], schema_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(schema_children)))); + + names.emplace_back("parquet_kv_metadata"); + vector kv_types; + vector kv_names; + ParquetMetaDataOperator::BindSchema(kv_types, kv_names); + child_list_t kv_children; + for (idx_t i = 0; i < kv_types.size(); i++) { + kv_children.push_back(make_pair(kv_names[i], kv_types[i])); + } + return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(kv_children)))); +} + +void FullMetadataProcessor::InitializeInternal(ClientContext &context, ParquetReader &reader) { + file_processor.Initialize(context, reader); + row_group_processor.Initialize(context, reader); + schema_processor.Initialize(context, reader); + kv_processor.Initialize(context, reader); +} + +idx_t FullMetadataProcessor::TotalRowCount(ParquetReader &reader) { + return 1; +} + +void FullMetadataProcessor::ReadRow(vector> &output, idx_t output_idx, idx_t row_idx, + ParquetReader &reader) { + PopulateMetadata(file_processor, output[0].get(), output_idx, reader); + PopulateMetadata(row_group_processor, output[1].get(), output_idx, reader); + PopulateMetadata(schema_processor, output[2].get(), output_idx, reader); + PopulateMetadata(kv_processor, output[3].get(), output_idx, reader); } //===--------------------------------------------------------------------===// @@ -859,6 +1001,10 @@ unique_ptr ParquetMetaDataOperator::InitLocal(Execution make_uniq(probe_bind_data.probe_column_name, probe_bind_data.probe_constant); break; } + case ParquetMetadataOperatorType::FULL_METADATA: { + res->processor = make_uniq(); + break; + } default: throw InternalException("Unsupported ParquetMetadataOperatorType"); } @@ -872,6 +1018,11 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu idx_t output_count = 0; + vector> output_vectors; + for (idx_t i = 0; i < output.ColumnCount(); i++) { + output_vectors.push_back(std::ref(output.data[i])); + } + while (output_count < STANDARD_VECTOR_SIZE) { // Check if we need a new file if (local_state.file_exhausted) { @@ -880,11 +1031,7 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu break; // No more files to process } - local_state.processor->Initialize(context, next_file); - local_state.processor->InitializeInternal(context); - 
local_state.file_exhausted = false; - local_state.row_idx = 0; - local_state.total_rows = local_state.processor->TotalRowCount(); + local_state.Initialize(context, next_file); } idx_t left_in_vector = STANDARD_VECTOR_SIZE - output_count; @@ -897,14 +1044,19 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu rows_to_output = left_in_vector; } + output.SetCardinality(output_count + rows_to_output); + for (idx_t i = 0; i < rows_to_output; ++i) { - local_state.processor->ReadRow(output, output_count + i, local_state.row_idx + i); + local_state.processor->ReadRow(output_vectors, output_count + i, local_state.row_idx + i, + *local_state.reader); } output_count += rows_to_output; local_state.row_idx += rows_to_output; - } - output.SetCardinality(output_count); + if (local_state.processor->ForceFlush()) { + break; + } + } } double ParquetMetaDataOperator::Progress(ClientContext &context, const FunctionData *bind_data_p, @@ -957,4 +1109,13 @@ ParquetBloomProbeFunction::ParquetBloomProbeFunction() ParquetMetaDataOperator::InitLocal) { table_scan_progress = ParquetMetaDataOperator::Progress; } + +ParquetFullMetadataFunction::ParquetFullMetadataFunction() + : TableFunction("parquet_full_metadata", {LogicalType::VARCHAR}, + ParquetMetaDataOperator::Function, + ParquetMetaDataOperator::Bind, + ParquetMetaDataOperator::InitGlobal, + ParquetMetaDataOperator::InitLocal) { + table_scan_progress = ParquetMetaDataOperator::Progress; +} } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp index 9617f0c83..160211b69 100644 --- a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp +++ b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp @@ -397,10 +397,6 @@ bool ParquetMultiFileInfo::ParseOption(ClientContext &context, const string &ori options.binary_as_string = BooleanValue::Get(val); return true; } - if (key == "variant_legacy_encoding") { - options.variant_legacy_encoding = BooleanValue::Get(val); - return true; - } if (key == "file_row_number") { options.file_row_number = BooleanValue::Get(val); return true; @@ -575,12 +571,21 @@ void ParquetReader::FinishFile(ClientContext &context, GlobalTableFunctionState gstate.row_group_index = 0; } -void ParquetReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, - LocalTableFunctionState &local_state_p, DataChunk &chunk) { +AsyncResult ParquetReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, + LocalTableFunctionState &local_state_p, DataChunk &chunk) { +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + { + vector> tasks = AsyncResult::GenerateTestTasks(); + if (!tasks.empty()) { + return AsyncResult(std::move(tasks)); + } + } +#endif + auto &gstate = gstate_p.Cast(); auto &local_state = local_state_p.Cast(); local_state.scan_state.op = gstate.op; - Scan(context, local_state.scan_state, chunk); + return Scan(context, local_state.scan_state, chunk); } unique_ptr ParquetMultiFileInfo::Copy() { diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index cad5f3a9b..6ef69a78b 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -5,7 +5,7 @@ #include "column_reader.hpp" #include "duckdb.hpp" #include "reader/expression_column_reader.hpp" -#include "geo_parquet.hpp" +#include "parquet_geometry.hpp" #include "reader/list_column_reader.hpp" #include "parquet_crypto.hpp" #include 
"parquet_file_metadata_cache.hpp" @@ -30,6 +30,8 @@ #include "duckdb/planner/table_filter_state.hpp" #include "duckdb/common/multi_file/multi_file_reader.hpp" #include "duckdb/logging/log_manager.hpp" +#include "duckdb/common/multi_file/multi_file_column_mapper.hpp" +#include "duckdb/common/encryption_functions.hpp" #include #include @@ -92,7 +94,7 @@ static shared_ptr LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &file_handle, const shared_ptr &encryption_config, const EncryptionUtil &encryption_util, optional_idx footer_size) { - auto file_proto = CreateThriftFileProtocol(QueryContext(context), file_handle, false); + auto file_proto = CreateThriftFileProtocol(context, file_handle, false); auto &transport = reinterpret_cast(*file_proto->getTransport()); auto file_size = transport.GetSize(); if (file_size < 12) { @@ -156,14 +158,25 @@ LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &fi } auto metadata = make_uniq(); + auto crypto_metadata = make_uniq(); + if (footer_encrypted) { - auto crypto_metadata = make_uniq(); crypto_metadata->read(file_proto.get()); + if (crypto_metadata->encryption_algorithm.__isset.AES_GCM_CTR_V1) { throw InvalidInputException("File '%s' is encrypted with AES_GCM_CTR_V1, but only AES_GCM_V1 is supported", file_handle.GetPath()); } - ParquetCrypto::Read(*metadata, *file_proto, encryption_config->GetFooterKey(), encryption_util); + auto file_aad = crypto_metadata->encryption_algorithm.AES_GCM_V1.aad_file_unique; + CryptoMetaData aad_crypto_metadata = CryptoMetaData(allocator); + + if (!file_aad.empty()) { + aad_crypto_metadata.Initialize(file_aad); + aad_crypto_metadata.SetModule(ParquetCrypto::FOOTER); + } + ParquetCrypto::GenerateAdditionalAuthenticatedData(allocator, aad_crypto_metadata); + ParquetCrypto::Read(*metadata, *file_proto, encryption_config->GetFooterKey(), encryption_util, + std::move(aad_crypto_metadata)); } else { metadata->read(file_proto.get()); } @@ -171,10 +184,15 @@ LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &fi // Try to read the GeoParquet metadata (if present) auto geo_metadata = GeoParquetFileMetadata::TryRead(*metadata, context); return make_shared_ptr(std::move(metadata), file_handle, std::move(geo_metadata), - footer_len); + std::move(crypto_metadata), footer_len); } LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, ParquetColumnSchema &schema) const { + return DeriveLogicalType(s_ele, parquet_options, schema); +} + +LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, const ParquetOptions &parquet_options, + ParquetColumnSchema &schema) { // inner node if (s_ele.type == Type::FIXED_LEN_BYTE_ARRAY && !s_ele.__isset.type_length) { throw IOException("FIXED_LEN_BYTE_ARRAY requires length to be set"); @@ -225,10 +243,6 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, Parquet return LogicalType::TIME_TZ; } return LogicalType::TIME; - } else if (s_ele.logicalType.__isset.GEOMETRY) { - return LogicalType::BLOB; - } else if (s_ele.logicalType.__isset.GEOGRAPHY) { - return LogicalType::BLOB; } } if (s_ele.__isset.converted_type) { @@ -396,20 +410,19 @@ LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, Parquet ParquetColumnSchema ParquetReader::ParseColumnSchema(const SchemaElement &s_ele, idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, ParquetColumnSchemaType type) { - ParquetColumnSchema schema(max_define, max_repeat, schema_index, 
column_index, type); - schema.name = s_ele.name; - schema.type = DeriveLogicalType(s_ele, schema); - return schema; + return ParquetColumnSchema::FromSchemaElement(s_ele, max_define, max_repeat, schema_index, column_index, type, + parquet_options); } unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &context, const vector &indexes, const ParquetColumnSchema &schema) { switch (schema.schema_type) { - case ParquetColumnSchemaType::GEOMETRY: - return GeoParquetFileMetadata::CreateColumnReader(*this, schema, context); case ParquetColumnSchemaType::FILE_ROW_NUMBER: return make_uniq(*this, schema); + case ParquetColumnSchemaType::GEOMETRY: { + return GeometryColumnReader::Create(*this, schema, context); + } case ParquetColumnSchemaType::COLUMN: { if (schema.children.empty()) { // leaf reader @@ -466,8 +479,8 @@ unique_ptr ParquetReader::CreateReader(ClientContext &context) { auto column_id = entry.first; auto &expression = entry.second; auto child_reader = std::move(root_struct_reader.child_readers[column_id]); - auto expr_schema = make_uniq(child_reader->Schema(), expression->return_type, - ParquetColumnSchemaType::EXPRESSION); + auto expr_schema = make_uniq(ParquetColumnSchema::FromParentSchema( + child_reader->Schema(), expression->return_type, ParquetColumnSchemaType::EXPRESSION)); auto expr_reader = make_uniq(context, std::move(child_reader), expression->Copy(), std::move(expr_schema)); root_struct_reader.child_readers[column_id] = std::move(expr_reader); @@ -475,102 +488,41 @@ unique_ptr ParquetReader::CreateReader(ClientContext &context) { return ret; } -ParquetColumnSchema::ParquetColumnSchema(idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, - ParquetColumnSchemaType schema_type) - : ParquetColumnSchema(string(), LogicalTypeId::INVALID, max_define, max_repeat, schema_index, column_index, - schema_type) { -} - -ParquetColumnSchema::ParquetColumnSchema(string name_p, LogicalType type_p, idx_t max_define, idx_t max_repeat, - idx_t schema_index, idx_t column_index, ParquetColumnSchemaType schema_type) - : schema_type(schema_type), name(std::move(name_p)), type(std::move(type_p)), max_define(max_define), - max_repeat(max_repeat), schema_index(schema_index), column_index(column_index) { -} - -ParquetColumnSchema::ParquetColumnSchema(ParquetColumnSchema parent, LogicalType result_type, - ParquetColumnSchemaType schema_type) - : schema_type(schema_type), name(parent.name), type(std::move(result_type)), max_define(parent.max_define), - max_repeat(parent.max_repeat), schema_index(parent.schema_index), column_index(parent.column_index) { - children.push_back(std::move(parent)); -} - -unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, - const ParquetOptions &parquet_options, idx_t row_group_idx_p, - const vector &columns) const { - if (schema_type == ParquetColumnSchemaType::EXPRESSION) { - return nullptr; - } - if (schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { - auto stats = NumericStats::CreateUnknown(type); - auto &row_groups = file_meta_data.row_groups; - D_ASSERT(row_group_idx_p < row_groups.size()); - idx_t row_group_offset_min = 0; - for (idx_t i = 0; i < row_group_idx_p; i++) { - row_group_offset_min += row_groups[i].num_rows; - } - - NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); - NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + - row_groups[row_group_idx_p].num_rows))); - stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - return stats.ToUnique(); 
- } - return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns, parquet_options.can_have_nan); -} - -static bool IsVariantType(const SchemaElement &root, const vector &children) { - if (children.size() < 2) { +static bool IsGeometryType(const SchemaElement &s_ele, const ParquetFileMetadataCache &metadata, idx_t depth) { + const auto is_blob = s_ele.__isset.type && s_ele.type == Type::BYTE_ARRAY; + if (!is_blob) { return false; } - auto &child0 = children[0]; - auto &child1 = children[1]; - ParquetColumnSchema const *metadata; - ParquetColumnSchema const *value; - - if (child0.name == "metadata" && child1.name == "value") { - metadata = &child0; - value = &child1; - } else if (child1.name == "metadata" && child0.name == "value") { - metadata = &child1; - value = &child0; - } else { - return false; + // TODO: Handle CRS in the future + const auto is_native_geom = s_ele.__isset.logicalType && s_ele.logicalType.__isset.GEOMETRY; + const auto is_native_geog = s_ele.__isset.logicalType && s_ele.logicalType.__isset.GEOGRAPHY; + if (is_native_geom || is_native_geog) { + return true; } - //! Verify names - if (metadata->name != "metadata") { - return false; - } - if (value->name != "value") { - return false; - } + // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata. + const auto is_at_root = depth == 1; + const auto is_in_gpq_metadata = metadata.geo_metadata && metadata.geo_metadata->IsGeometryColumn(s_ele.name); + const auto is_leaf = s_ele.num_children == 0; + const auto is_geoparquet_geom = is_at_root && is_in_gpq_metadata && is_leaf; - //! Verify types - if (metadata->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { - return false; - } - if (value->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { - return false; - } - if (children.size() == 3) { - auto &typed_value = children[2]; - if (typed_value.name != "typed_value") { - return false; - } - } else if (children.size() != 2) { - return false; + if (is_geoparquet_geom) { + return true; } - return true; + + return false; } ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_define, idx_t max_repeat, idx_t &next_schema_idx, idx_t &next_file_idx, ClientContext &context) { - auto file_meta_data = GetFileMetadata(); D_ASSERT(file_meta_data); - D_ASSERT(next_schema_idx < file_meta_data->schema.size()); + if (next_schema_idx >= file_meta_data->schema.size()) { + throw InvalidInputException("Malformed Parquet schema in file \"%s\": invalid schema index %d", file.path, + next_schema_idx); + } auto &s_ele = file_meta_data->schema[next_schema_idx]; auto this_idx = next_schema_idx; @@ -585,15 +537,24 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d max_repeat++; } - // Check for geoparquet spatial types - if (depth == 1) { - // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata. - // geoarrow types, although geometry columns, are structs and have children and are handled below. - if (metadata->geo_metadata && metadata->geo_metadata->IsGeometryColumn(s_ele.name) && s_ele.num_children == 0) { - auto root_schema = ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx++); - return ParquetColumnSchema(std::move(root_schema), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); - } + // Check for geometry type + if (IsGeometryType(s_ele, *metadata, depth)) { + // Geometries in both GeoParquet and native parquet are stored as a WKB-encoded BLOB. 
+ // Because we don't just want to validate that the WKB encoding is correct, but also transform it into + // little-endian if necessary, we cant just make use of the StringColumnReader without heavily modifying it. + // Therefore, we create a dedicated GEOMETRY parquet column schema type, which wraps the underlying BLOB column. + // This schema type gets instantiated as a ExpressionColumnReader on top of the standard Blob/String reader, + // which performs the WKB validation/transformation using the `ST_GeomFromWKB` function of DuckDB. + // This enables us to also support other geometry encodings (such as GeoArrow geometries) easier in the future. + + // Inner BLOB schema + vector geometry_child; + geometry_child.emplace_back(ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx)); + + // Wrap in geometry schema + return ParquetColumnSchema::FromChildSchemas(s_ele.name, LogicalType::GEOMETRY(), max_define, max_repeat, + this_idx, next_file_idx++, std::move(geometry_child), + ParquetColumnSchemaType::GEOMETRY); } if (s_ele.__isset.num_children && s_ele.num_children > 0) { // inner node @@ -627,9 +588,6 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d const bool is_map = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP; bool is_map_kv = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP_KEY_VALUE; bool is_variant = s_ele.__isset.logicalType && s_ele.logicalType.__isset.VARIANT == true; - if (!is_variant) { - is_variant = parquet_options.variant_legacy_encoding && IsVariantType(s_ele, child_schemas); - } if (!is_map_kv && this_idx > 0) { // check if the parent node of this is a map @@ -647,14 +605,12 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d throw IOException("MAP_KEY_VALUE needs to be repeated"); } auto result_type = LogicalType::MAP(child_schemas[0].type, child_schemas[1].type); - ParquetColumnSchema struct_schema(s_ele.name, ListType::GetChildType(result_type), max_define - 1, - max_repeat - 1, this_idx, next_file_idx); - struct_schema.children = std::move(child_schemas); - - ParquetColumnSchema map_schema(s_ele.name, std::move(result_type), max_define, max_repeat, this_idx, - next_file_idx); - map_schema.children.push_back(std::move(struct_schema)); - return map_schema; + vector map_children; + map_children.emplace_back(ParquetColumnSchema::FromChildSchemas( + s_ele.name, ListType::GetChildType(result_type), max_define - 1, max_repeat - 1, this_idx, + next_file_idx, std::move(child_schemas))); + return ParquetColumnSchema::FromChildSchemas(s_ele.name, result_type, max_define, max_repeat, this_idx, + next_file_idx, std::move(map_children)); } ParquetColumnSchema result; if (child_schemas.size() > 1 || (!is_list && !is_map && !is_repeated)) { @@ -665,17 +621,14 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d LogicalType result_type; if (is_variant) { - result_type = LogicalType::JSON(); + result_type = LogicalType::VARIANT(); } else { result_type = LogicalType::STRUCT(std::move(struct_types)); } - ParquetColumnSchema struct_schema(s_ele.name, std::move(result_type), max_define, max_repeat, this_idx, - next_file_idx); - struct_schema.children = std::move(child_schemas); - if (is_variant) { - struct_schema.schema_type = ParquetColumnSchemaType::VARIANT; - } - result = std::move(struct_schema); + ParquetColumnSchemaType schema_type = + is_variant ? 
ParquetColumnSchemaType::VARIANT : ParquetColumnSchemaType::COLUMN; + result = ParquetColumnSchema::FromChildSchemas(s_ele.name, result_type, max_define, max_repeat, this_idx, + next_file_idx, std::move(child_schemas), schema_type); } else { // if we have a struct with only a single type, pull up result = std::move(child_schemas[0]); @@ -683,10 +636,9 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d } if (is_repeated) { auto list_type = LogicalType::LIST(result.type); - ParquetColumnSchema list_schema(s_ele.name, std::move(list_type), max_define, max_repeat, this_idx, - next_file_idx); - list_schema.children.push_back(std::move(result)); - result = std::move(list_schema); + vector list_child = {std::move(result)}; + result = ParquetColumnSchema::FromChildSchemas(s_ele.name, std::move(list_type), max_define, max_repeat, + this_idx, next_file_idx, std::move(list_child)); } result.parent_schema_index = this_idx; return result; @@ -699,17 +651,9 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d auto result = ParseColumnSchema(s_ele, max_define, max_repeat, this_idx, next_file_idx++); if (s_ele.repetition_type == FieldRepetitionType::REPEATED) { auto list_type = LogicalType::LIST(result.type); - ParquetColumnSchema list_schema(s_ele.name, std::move(list_type), max_define, max_repeat, this_idx, - next_file_idx); - list_schema.children.push_back(std::move(result)); - return list_schema; - } - - // Convert to geometry type if possible - if (s_ele.__isset.logicalType && (s_ele.logicalType.__isset.GEOMETRY || s_ele.logicalType.__isset.GEOGRAPHY) && - GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - return ParquetColumnSchema(std::move(result), GeoParquetFileMetadata::GeometryType(), - ParquetColumnSchemaType::GEOMETRY); + vector list_child = {std::move(result)}; + return ParquetColumnSchema::FromChildSchemas(s_ele.name, std::move(list_type), max_define, max_repeat, + this_idx, next_file_idx, std::move(list_child)); } return result; @@ -717,8 +661,7 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d } static ParquetColumnSchema FileRowNumberSchema() { - return ParquetColumnSchema("file_row_number", LogicalType::BIGINT, 0, 0, 0, 0, - ParquetColumnSchemaType::FILE_ROW_NUMBER); + return ParquetColumnSchema::FileRowNumber(); } unique_ptr ParquetReader::ParseSchema(ClientContext &context) { @@ -727,23 +670,28 @@ unique_ptr ParquetReader::ParseSchema(ClientContext &contex idx_t next_file_idx = 0; if (file_meta_data->schema.empty()) { - throw IOException("Parquet reader: no schema elements found"); + throw IOException("Failed to read Parquet file \"%s\": no schema elements found", file.path); } if (file_meta_data->schema[0].num_children == 0) { - throw IOException("Parquet reader: root schema element has no children"); + throw IOException("Failed to read Parquet file \"%s\": root schema element has no children", file.path); } auto root = ParseSchemaRecursive(0, 0, 0, next_schema_idx, next_file_idx, context); if (root.type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("Root element of Parquet file must be a struct"); + throw InvalidInputException("Failed to read Parquet file \"%s\": Root element of Parquet file must be a struct", + file.path); } D_ASSERT(next_schema_idx == file_meta_data->schema.size() - 1); - D_ASSERT(file_meta_data->row_groups.empty() || next_file_idx == file_meta_data->row_groups[0].columns.size()); + if (!file_meta_data->row_groups.empty() && 
next_file_idx != file_meta_data->row_groups[0].columns.size()) { + throw InvalidInputException("Failed to read Parquet file \"%s\": row group does not have enough columns", + file.path); + } if (parquet_options.file_row_number) { for (auto &column : root.children) { auto &name = column.name; if (StringUtil::CIEquals(name, "file_row_number")) { - throw BinderException( - "Using file_row_number option on file with column named file_row_number is not supported"); + throw BinderException("Failed to read Parquet file \"%s\": Using file_row_number option on file with " + "column named file_row_number is not supported", + file.path); } } root.children.push_back(FileRowNumberSchema()); @@ -758,7 +706,8 @@ MultiFileColumnDefinition ParquetReader::ParseColumnDefinition(const FileMetaDat result.identifier = Value::INTEGER(MultiFileReader::ORDINAL_FIELD_ID); return result; } - auto &column_schema = file_meta_data.schema[element.schema_index]; + D_ASSERT(element.schema_index.IsValid()); + auto &column_schema = file_meta_data.schema[element.schema_index.GetIndex()]; if (column_schema.__isset.field_id) { result.identifier = Value::INTEGER(column_schema.field_id); @@ -808,9 +757,6 @@ ParquetOptions::ParquetOptions(ClientContext &context) { if (context.TryGetCurrentSetting("binary_as_string", lookup_value)) { binary_as_string = lookup_value.GetValue(); } - if (context.TryGetCurrentSetting("variant_legacy_encoding", lookup_value)) { - variant_legacy_encoding = lookup_value.GetValue(); - } } ParquetColumnDefinition ParquetColumnDefinition::FromSchemaValue(ClientContext &context, const Value &column_value) { @@ -837,7 +783,7 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq shared_ptr metadata_p) : BaseFileReader(std::move(file_p)), fs(CachingFileSystem::Get(context_p)), allocator(BufferAllocator::Get(context_p)), parquet_options(std::move(parquet_options_p)) { - file_handle = fs.OpenFile(QueryContext(context_p), file, FileFlags::FILE_FLAGS_READ); + file_handle = fs.OpenFile(context_p, file, FileFlags::FILE_FLAGS_READ); if (!file_handle->CanSeek()) { throw NotImplementedException( "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " @@ -983,22 +929,37 @@ unique_ptr ParquetReader::ReadStatistics(const ParquetUnionData file_col_idx); } -uint32_t ParquetReader::Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot) { - if (parquet_options.encryption_config) { - return ParquetCrypto::Read(object, iprot, parquet_options.encryption_config->GetFooterKey(), *encryption_util); +string ParquetReader::GetUniqueFileIdentifier(const duckdb_parquet::EncryptionAlgorithm &encryption_algorithm) { + if (encryption_algorithm.__isset.AES_GCM_V1) { + return encryption_algorithm.AES_GCM_V1.aad_file_unique; + } else if (encryption_algorithm.__isset.AES_GCM_CTR_V1) { + throw InternalException("File is encrypted with AES_GCM_CTR_V1, but this is not supported by DuckDB"); } else { - return object.read(&iprot); + throw InternalException("File is encrypted but no encryption algorithm is set"); } } +uint32_t ParquetReader::Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot) { + return object.read(&iprot); +} + +uint32_t ParquetReader::ReadEncrypted(duckdb_apache::thrift::TBase &object, TProtocol &iprot, + CryptoMetaData &aad_crypto_metadata) const { + ParquetCrypto::GenerateAdditionalAuthenticatedData(allocator, aad_crypto_metadata); + return ParquetCrypto::Read(object, iprot, parquet_options.encryption_config->GetFooterKey(), 
*encryption_util, + aad_crypto_metadata); +} + uint32_t ParquetReader::ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size) { - if (parquet_options.encryption_config) { - return ParquetCrypto::ReadData(iprot, buffer, buffer_size, parquet_options.encryption_config->GetFooterKey(), - *encryption_util); - } else { - return iprot.getTransport()->read(buffer, buffer_size); - } + return iprot.getTransport()->read(buffer, buffer_size); +} + +uint32_t ParquetReader::ReadDataEncrypted(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, + const uint32_t buffer_size, CryptoMetaData &aad_crypto_metadata) const { + ParquetCrypto::GenerateAdditionalAuthenticatedData(allocator, aad_crypto_metadata); + return ParquetCrypto::ReadData(iprot, buffer, buffer_size, parquet_options.encryption_config->GetFooterKey(), + *encryption_util, aad_crypto_metadata); } static idx_t GetRowGroupOffset(ParquetReader &reader, idx_t group_idx) { @@ -1046,7 +1007,6 @@ uint64_t ParquetReader::GetGroupSpan(ParquetReaderScanState &state) { idx_t max_offset = NumericLimits::Minimum(); for (auto &column_chunk : group.columns) { - // Set the min offset idx_t current_min_offset = NumericLimits::Maximum(); if (column_chunk.meta_data.__isset.dictionary_page_offset) { @@ -1144,6 +1104,12 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i auto column_id = column_ids[col_idx]; auto &column_reader = state.root_reader->Cast().GetChildReader(column_id); + // keep track of column and row group ordinal if data is encrypted + if (metadata->crypto_metadata->encryption_algorithm.__isset.AES_GCM_CTR_V1) { + column_reader.InitializeCryptoMetadata(metadata->crypto_metadata->encryption_algorithm, + GetGroup(state).ordinal); + } + if (filters) { auto stats = column_reader.Stats(state.group_idx_list[state.current_group], group.columns); // filters contain output chunk index, not file col idx! 
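The encrypted-read plumbing above (CryptoMetaData, GenerateAdditionalAuthenticatedData, the per-row-group InitializeCryptoMetadata calls) follows Parquet's modular encryption scheme, where every encrypted module is authenticated against an AAD derived from the file's aad_file_unique plus a module type and ordinals. A minimal sketch of how such a module AAD suffix can be assembled, assuming the layout and module-type codes from the Parquet encryption specification rather than DuckDB's actual helpers:

#include <cstdint>
#include <stdexcept>
#include <string>

// Module-type codes per the Parquet modular-encryption spec (assumption: only the subset
// relevant to this patch is listed here).
enum class ParquetCryptoModule : uint8_t {
	FOOTER = 0,
	COLUMN_METADATA = 1,
	DATA_PAGE = 2,
	DICTIONARY_PAGE = 3,
	DATA_PAGE_HEADER = 4,
	DICTIONARY_PAGE_HEADER = 5
};

// Ordinals are serialized as little-endian 16-bit values, which is why the writer side of this
// patch rejects row-group ordinals above 32767 when encryption is enabled.
static void AppendOrdinal(std::string &aad, int32_t ordinal) {
	if (ordinal < 0 || ordinal > INT16_MAX) {
		throw std::runtime_error("ordinal does not fit the 16-bit field required for encryption");
	}
	aad.push_back(static_cast<char>(ordinal & 0xFF));
	aad.push_back(static_cast<char>((ordinal >> 8) & 0xFF));
}

// file_aad is the aad_file_unique taken from the footer's AES_GCM_V1 encryption algorithm.
static std::string BuildModuleAAD(const std::string &file_aad, ParquetCryptoModule module,
                                  int32_t row_group_ordinal, int32_t column_ordinal, int32_t page_ordinal) {
	std::string aad = file_aad;
	aad.push_back(static_cast<char>(module));
	if (module == ParquetCryptoModule::FOOTER) {
		return aad; // the footer module carries no ordinals
	}
	AppendOrdinal(aad, row_group_ordinal);
	AppendOrdinal(aad, column_ordinal);
	if (module == ParquetCryptoModule::DATA_PAGE || module == ParquetCryptoModule::DATA_PAGE_HEADER) {
		AppendOrdinal(aad, page_ordinal); // only page-level modules carry a page ordinal
	}
	return aad;
}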
@@ -1236,7 +1202,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat state.prefetch_mode = false; } - state.file_handle = fs.OpenFile(QueryContext(context), file, flags); + state.file_handle = fs.OpenFile(context, file, flags); } state.adaptive_filter.reset(); state.scan_filters.clear(); @@ -1247,21 +1213,12 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat } } - state.thrift_file_proto = CreateThriftFileProtocol(QueryContext(context), *state.file_handle, state.prefetch_mode); + state.thrift_file_proto = CreateThriftFileProtocol(context, *state.file_handle, state.prefetch_mode); state.root_reader = CreateReader(context); state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); } -void ParquetReader::Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { - while (ScanInternal(context, state, result)) { - if (result.size() > 0) { - break; - } - result.Reset(); - } -} - void ParquetReader::GetPartitionStats(vector &result) { GetPartitionStats(*GetFileMetadata(), result); } @@ -1279,9 +1236,10 @@ void ParquetReader::GetPartitionStats(const duckdb_parquet::FileMetaData &metada } } -bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { +AsyncResult ParquetReader::Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { + result.Reset(); if (state.finished) { - return false; + return SourceResultType::FINISHED; } // see if we have to switch to the next row group in the parquet file @@ -1295,7 +1253,7 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState if ((idx_t)state.current_group == state.group_idx_list.size()) { state.finished = true; - return false; + return SourceResultType::FINISHED; } // TODO: only need this if we have a deletion vector? 
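With the wrapper around ScanInternal removed, ParquetReader::Scan now resets the chunk itself and reports progress through the returned AsyncResult/SourceResultType, so the "keep scanning until a non-empty chunk arrives" loop moves to the caller. A minimal caller sketch of that contract; AsyncResult::GetResultType() and ConsumeChunk() are illustrative placeholders, not APIs shown in this diff:

#include "parquet_reader.hpp" // assumed to pull in ClientContext, DataChunk and SourceResultType

namespace duckdb {

static void ConsumeChunk(DataChunk &chunk); // placeholder for whatever the caller does with the rows

static void DrainRowGroups(ClientContext &context, ParquetReader &reader, ParquetReaderScanState &state,
                           DataChunk &chunk) {
	while (true) {
		auto result = reader.Scan(context, state, chunk); // Scan() resets `chunk` before filling it
		if (result.GetResultType() == SourceResultType::FINISHED) {
			break; // every row group in state.group_idx_list has been read
		}
		// HAVE_MORE_OUTPUT can be returned with an empty chunk (for example right after switching to
		// the next row group); the removed wrapper used to loop over ScanInternal() for this case,
		// so the caller now has to tolerate it and simply ask again.
		if (chunk.size() == 0) {
			continue;
		}
		ConsumeChunk(chunk);
	}
}

} // namespace duckdb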
@@ -1367,7 +1325,8 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState } } } - return true; + result.Reset(); + return SourceResultType::HAVE_MORE_OUTPUT; } auto scan_count = MinValue(STANDARD_VECTOR_SIZE, GetGroup(state).num_rows - state.offset_in_group); @@ -1375,7 +1334,8 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState if (scan_count == 0) { state.finished = true; - return false; // end of last group, we are done + // end of last group, we are done + return SourceResultType::FINISHED; } auto &deletion_filter = state.root_reader->Reader().deletion_filter; @@ -1440,6 +1400,10 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState } auto &result_vector = result.data[i]; auto &child_reader = root_reader.GetChildReader(file_col_idx); + if (metadata->crypto_metadata->encryption_algorithm.__isset.AES_GCM_V1) { + child_reader.InitializeCryptoMetadata(metadata->crypto_metadata->encryption_algorithm, + GetGroup(state).ordinal); + } child_reader.Select(result.size(), define_ptr, repeat_ptr, result_vector, state.sel, filter_count); } if (scan_count != filter_count) { @@ -1451,6 +1415,10 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState auto file_col_idx = column_ids[col_idx]; auto &result_vector = result.data[i]; auto &child_reader = root_reader.GetChildReader(file_col_idx); + if (metadata->crypto_metadata->encryption_algorithm.__isset.AES_GCM_V1) { + child_reader.InitializeCryptoMetadata(metadata->crypto_metadata->encryption_algorithm, + GetGroup(state).ordinal); + } auto rows_read = child_reader.Read(scan_count, define_ptr, repeat_ptr, result_vector); if (rows_read != scan_count) { throw InvalidInputException("Mismatch in parquet read for column %llu, expected %llu rows, got %llu", @@ -1461,7 +1429,7 @@ bool ParquetReader::ScanInternal(ClientContext &context, ParquetReaderScanState rows_read += scan_count; state.offset_in_group += scan_count; - return true; + return SourceResultType::HAVE_MORE_OUTPUT; } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_shredding.cpp b/src/duckdb/extension/parquet/parquet_shredding.cpp new file mode 100644 index 000000000..b7ed673a8 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_shredding.cpp @@ -0,0 +1,81 @@ +#include "parquet_shredding.hpp" +#include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/type_visitor.hpp" + +namespace duckdb { + +ChildShreddingTypes::ChildShreddingTypes() : types(make_uniq>()) { +} + +ChildShreddingTypes ChildShreddingTypes::Copy() const { + ChildShreddingTypes result; + for (const auto &type : *types) { + result.types->emplace(type.first, type.second.Copy()); + } + return result; +} + +ShreddingType::ShreddingType() : set(false) { +} + +ShreddingType::ShreddingType(const LogicalType &type) : set(true), type(type) { +} + +ShreddingType ShreddingType::Copy() const { + auto result = set ? 
ShreddingType(type) : ShreddingType(); + result.children = children.Copy(); + return result; +} + +static ShreddingType ConvertShreddingTypeRecursive(const LogicalType &type) { + if (type.id() == LogicalTypeId::VARIANT) { + return ShreddingType(LogicalType(LogicalTypeId::ANY)); + } + if (!type.IsNested()) { + return ShreddingType(type); + } + + switch (type.id()) { + case LogicalTypeId::STRUCT: { + ShreddingType res(type); + auto &children = StructType::GetChildTypes(type); + for (auto &entry : children) { + res.AddChild(entry.first, ConvertShreddingTypeRecursive(entry.second)); + } + return res; + } + case LogicalTypeId::LIST: { + ShreddingType res(type); + const auto &child = ListType::GetChildType(type); + res.AddChild("element", ConvertShreddingTypeRecursive(child)); + return res; + } + default: + break; + } + throw BinderException("VARIANT can only be shredded on LIST/STRUCT/ANY/non-nested type, not %s", type.ToString()); +} + +void ShreddingType::AddChild(const string &name, ShreddingType &&child) { + children.types->emplace(name, std::move(child)); +} + +optional_ptr ShreddingType::GetChild(const string &name) const { + auto it = children.types->find(name); + if (it == children.types->end()) { + return nullptr; + } + return it->second; +} + +ShreddingType ShreddingType::GetShreddingTypes(const Value &val) { + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("SHREDDING value should be of type VARCHAR, a stringified type to use for the column"); + } + auto type_str = val.GetValue(); + auto logical_type = TransformStringToLogicalType(type_str); + + return ConvertShreddingTypeRecursive(logical_type); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp index 5f7d93718..27c5daacc 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -322,7 +322,6 @@ Value ParquetStatisticsUtils::ConvertValueInternal(const LogicalType &type, cons unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(const ParquetColumnSchema &schema, const vector &columns, bool can_have_nan) { - // Not supported types auto &type = schema.type; if (type.id() == LogicalTypeId::ARRAY || type.id() == LogicalTypeId::MAP || type.id() == LogicalTypeId::LIST) { @@ -395,26 +394,71 @@ unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(con } break; case LogicalTypeId::VARCHAR: { - auto string_stats = StringStats::CreateEmpty(type); + auto string_stats = StringStats::CreateUnknown(type); if (parquet_stats.__isset.min_value) { StringColumnReader::VerifyString(parquet_stats.min_value.c_str(), parquet_stats.min_value.size(), true); - StringStats::Update(string_stats, parquet_stats.min_value); + StringStats::SetMin(string_stats, parquet_stats.min_value); } else if (parquet_stats.__isset.min) { StringColumnReader::VerifyString(parquet_stats.min.c_str(), parquet_stats.min.size(), true); - StringStats::Update(string_stats, parquet_stats.min); + StringStats::SetMin(string_stats, parquet_stats.min); } if (parquet_stats.__isset.max_value) { StringColumnReader::VerifyString(parquet_stats.max_value.c_str(), parquet_stats.max_value.size(), true); - StringStats::Update(string_stats, parquet_stats.max_value); + StringStats::SetMax(string_stats, parquet_stats.max_value); } else if (parquet_stats.__isset.max) { StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); - StringStats::Update(string_stats, 
parquet_stats.max); + StringStats::SetMax(string_stats, parquet_stats.max); } - StringStats::SetContainsUnicode(string_stats); - StringStats::ResetMaxStringLength(string_stats); row_group_stats = string_stats.ToUnique(); break; } + case LogicalTypeId::GEOMETRY: { + auto geo_stats = GeometryStats::CreateUnknown(type); + if (column_chunk.meta_data.__isset.geospatial_statistics) { + if (column_chunk.meta_data.geospatial_statistics.__isset.bbox) { + auto &bbox = column_chunk.meta_data.geospatial_statistics.bbox; + auto &stats_bbox = GeometryStats::GetExtent(geo_stats); + + // xmin > xmax is allowed if the geometry crosses the antimeridian, + // but we don't handle this right now + if (bbox.xmin <= bbox.xmax) { + stats_bbox.x_min = bbox.xmin; + stats_bbox.x_max = bbox.xmax; + } + + if (bbox.ymin <= bbox.ymax) { + stats_bbox.y_min = bbox.ymin; + stats_bbox.y_max = bbox.ymax; + } + + if (bbox.__isset.zmin && bbox.__isset.zmax && bbox.zmin <= bbox.zmax) { + stats_bbox.z_min = bbox.zmin; + stats_bbox.z_max = bbox.zmax; + } + + if (bbox.__isset.mmin && bbox.__isset.mmax && bbox.mmin <= bbox.mmax) { + stats_bbox.m_min = bbox.mmin; + stats_bbox.m_max = bbox.mmax; + } + } + if (column_chunk.meta_data.geospatial_statistics.__isset.geospatial_types) { + auto &types = column_chunk.meta_data.geospatial_statistics.geospatial_types; + auto &stats_types = GeometryStats::GetTypes(geo_stats); + + // if types are set but empty, that still means "any type" - so we leave stats_types as-is (unknown) + // otherwise, clear and set to the actual types + + if (!types.empty()) { + stats_types.Clear(); + for (auto &geom_type : types) { + stats_types.AddWKBType(geom_type); + } + } + } + } + row_group_stats = geo_stats.ToUnique(); + break; + } default: // no stats for you break; @@ -580,7 +624,6 @@ bool ParquetStatisticsUtils::BloomFilterExcludes(const TableFilter &duckdb_filte } ParquetBloomFilter::ParquetBloomFilter(idx_t num_entries, double bloom_filter_false_positive_ratio) { - // aim for hit ratio of 0.01% // see http://tfk.mit.edu/pdf/bloom.pdf double f = bloom_filter_false_positive_ratio; diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 2021335ad..35394ebe8 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -3,8 +3,10 @@ #include "duckdb.hpp" #include "mbedtls_wrapper.hpp" #include "parquet_crypto.hpp" +#include "parquet_shredding.hpp" #include "parquet_timestamp.hpp" #include "resizable_buffer.hpp" +#include "duckdb/parser/keyword_helper.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/serializer/buffered_file_writer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -12,6 +14,7 @@ #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/function/table_function.hpp" +#include "duckdb/main/extension_helper.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/connection.hpp" #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" @@ -35,29 +38,6 @@ using duckdb_parquet::PageType; using ParquetRowGroup = duckdb_parquet::RowGroup; using duckdb_parquet::Type; -ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { -} - -ChildFieldIDs ChildFieldIDs::Copy() const { - ChildFieldIDs result; - for (const auto &id : *ids) { - result.ids->emplace(id.first, id.second.Copy()); - } - return result; -} - -FieldID::FieldID() : set(false) { -} - -FieldID::FieldID(int32_t field_id_p) : 
set(true), field_id(field_id_p) { -} - -FieldID FieldID::Copy() const { - auto result = set ? FieldID(field_id) : FieldID(); - result.child_field_ids = child_field_ids.Copy(); - return result; -} - class MyTransport : public TTransport { public: explicit MyTransport(WriteStream &serializer) : serializer(serializer) { @@ -109,6 +89,7 @@ bool ParquetWriter::TryGetParquetType(const LogicalType &duckdb_type, optional_p case LogicalTypeId::ENUM: case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: + case LogicalTypeId::GEOMETRY: parquet_type = Type::BYTE_ARRAY; break; case LogicalTypeId::TIME: @@ -166,7 +147,8 @@ Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); } -void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele) { +void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry) { if (duckdb_type.IsJSONType()) { schema_ele.converted_type = ConvertedType::JSON; schema_ele.__isset.converted_type = true; @@ -174,13 +156,6 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.__set_JSON(duckdb_parquet::JsonType()); return; } - if (duckdb_type.GetAlias() == "WKB_BLOB") { - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.GEOMETRY = true; - // TODO: Set CRS in the future - schema_ele.logicalType.GEOMETRY.__isset.crs = false; - return; - } switch (duckdb_type.id()) { case LogicalTypeId::TINYINT: schema_ele.converted_type = ConvertedType::INT_8; @@ -285,6 +260,13 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.DECIMAL.precision = schema_ele.precision; schema_ele.logicalType.DECIMAL.scale = schema_ele.scale; break; + case LogicalTypeId::GEOMETRY: + if (allow_geometry) { // Don't set this if we write GeoParquet V1 + schema_ele.__isset.logicalType = true; + schema_ele.logicalType.__isset.GEOMETRY = true; + // TODO: Set CRS in the future + schema_ele.logicalType.GEOMETRY.__isset.crs = false; + } default: break; } @@ -336,9 +318,9 @@ struct ColumnStatsUnifier { bool can_have_nan = false; bool has_nan = false; - unique_ptr geo_stats; + unique_ptr geo_stats; - virtual void UnifyGeoStats(const GeometryStats &other) { + virtual void UnifyGeoStats(const GeometryStatsData &other) { } virtual void UnifyMinMax(const string &new_min, const string &new_max) = 0; @@ -350,27 +332,55 @@ class ParquetStatsAccumulator { vector> stats_unifiers; }; +ParquetWriteTransformData::ParquetWriteTransformData(ClientContext &context, vector types, + vector> expressions_p) + : buffer(context, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR), expressions(std::move(expressions_p)), + executor(context, expressions) { + chunk.Initialize(buffer.GetAllocator(), types); +} + +//! TODO: this doesnt work.. the ParquetWriteTransformData is shared with all threads, the method is stateful, but has +//! no locks Either every local state needs its own copy of this or we need a lock so its used by one thread at a time.. +//! 
The former has my preference +ColumnDataCollection &ParquetWriteTransformData::ApplyTransform(ColumnDataCollection &input) { + buffer.Reset(); + for (auto &input_chunk : input.Chunks()) { + chunk.Reset(); + executor.Execute(input_chunk, chunk); + buffer.Append(chunk); + } + return buffer; +} + ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, - const vector> &kv_metadata, + ShreddingType shredding_types_p, const vector> &kv_metadata, shared_ptr encryption_config_p, optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p, double bloom_filter_false_positive_ratio_p, - int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version) + int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version, + GeoParquetVersion geoparquet_version) : context(context), file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)), - encryption_config(std::move(encryption_config_p)), dictionary_size_limit(dictionary_size_limit_p), + shredding_types(std::move(shredding_types_p)), encryption_config(std::move(encryption_config_p)), + dictionary_size_limit(dictionary_size_limit_p), string_dictionary_page_size_limit(string_dictionary_page_size_limit_p), enable_bloom_filters(enable_bloom_filters_p), bloom_filter_false_positive_ratio(bloom_filter_false_positive_ratio_p), compression_level(compression_level_p), - debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), total_written(0), num_row_groups(0) { - + debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), geoparquet_version(geoparquet_version), + total_written(0), num_row_groups(0) { // initialize the file writer writer = make_uniq(fs, file_name.c_str(), FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); if (encryption_config) { auto &config = DBConfig::GetConfig(context); + + // To ensure we can write, we need to autoload httpfs + if (!config.encryption_util || !config.encryption_util->SupportsEncryption()) { + ExtensionHelper::TryAutoLoadExtension(context, "httpfs"); + } + if (config.encryption_util && debug_use_openssl) { // Use OpenSSL encryption_util = config.encryption_util; @@ -390,14 +400,12 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file protocol = tproto_factory.getProtocol(duckdb_base_std::make_shared(*writer)); file_meta_data.num_rows = 0; - file_meta_data.version = 1; + file_meta_data.version = UnsafeNumericCast(parquet_version); file_meta_data.__isset.created_by = true; file_meta_data.created_by = StringUtil::Format("DuckDB version %s (build %s)", DuckDB::LibraryVersion(), DuckDB::SourceID()); - file_meta_data.schema.resize(1); - for (auto &kv_pair : kv_metadata) { duckdb_parquet::KeyValue kv; kv.__set_key(kv_pair.first); @@ -406,34 +414,132 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file file_meta_data.__isset.key_value_metadata = true; } - // populate root schema object - file_meta_data.schema[0].name = "duckdb_schema"; - file_meta_data.schema[0].num_children = NumericCast(sql_types.size()); - file_meta_data.schema[0].__isset.num_children = true; - file_meta_data.schema[0].repetition_type = duckdb_parquet::FieldRepetitionType::REQUIRED; - file_meta_data.schema[0].__isset.repetition_type = true; - auto 
&unique_names = column_names; VerifyUniqueNames(unique_names); - // construct the child schemas + // V1 GeoParquet stores geometries as blobs, no logical type + auto allow_geometry = geoparquet_version != GeoParquetVersion::V1; + + // construct the column writers + D_ASSERT(sql_types.size() == unique_names.size()); for (idx_t i = 0; i < sql_types.size(); i++) { - auto child_schema = - ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], &field_ids); - column_schemas.push_back(std::move(child_schema)); - } - // now construct the writers based on the schemas - for (auto &child_schema : column_schemas) { vector path_in_schema; - column_writers.push_back( - ColumnWriter::CreateWriterRecursive(context, *this, file_meta_data.schema, child_schema, path_in_schema)); + column_writers.push_back(ColumnWriter::CreateWriterRecursive(context, *this, path_in_schema, sql_types[i], + unique_names[i], allow_geometry, &field_ids, + &shredding_types)); } } ParquetWriter::~ParquetWriter() { } -void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result) { +void ParquetWriter::AnalyzeSchema(ColumnDataCollection &buffer, vector> &column_writers) { + D_ASSERT(buffer.ColumnCount() == column_writers.size()); + vector> states; + bool needs_analyze = false; + lock_guard glock(lock); + + vector column_ids; + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &writer = column_writers[i]; + auto state = writer->AnalyzeSchemaInit(); + if (state) { + needs_analyze = true; + states.push_back(std::move(state)); + column_ids.push_back(i); + } else { + states.push_back(nullptr); + } + } + + if (!needs_analyze) { + return; + } + + for (auto &chunk : buffer.Chunks(column_ids)) { + idx_t index = 0; + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &state = states[i]; + if (!state) { + continue; + } + auto &writer = column_writers[i]; + writer->AnalyzeSchema(*state, chunk.data[index++], chunk.size()); + } + } + + for (idx_t i = 0; i < column_writers.size(); i++) { + auto &writer = column_writers[i]; + auto &state = states[i]; + if (!state) { + continue; + } + writer->AnalyzeSchemaFinalize(*state); + } +} + +void ParquetWriter::InitializePreprocessing(unique_ptr &transform_data) { + if (transform_data) { + return; + } + + vector transformed_types; + vector> transform_expressions; + for (idx_t col_idx = 0; col_idx < column_writers.size(); col_idx++) { + auto &column_writer = *column_writers[col_idx]; + auto &original_type = sql_types[col_idx]; + auto expr = make_uniq(original_type, col_idx); + if (!column_writer.HasTransform()) { + transformed_types.push_back(original_type); + transform_expressions.push_back(std::move(expr)); + continue; + } + transformed_types.push_back(column_writer.TransformedType()); + transform_expressions.push_back(column_writer.TransformExpression(std::move(expr))); + } + transform_data = make_uniq(context, transformed_types, std::move(transform_expressions)); +} + +void ParquetWriter::InitializeSchemaElements() { + //! 
Populate the schema elements of the parquet file we're writing + lock_guard glock(lock); + if (!file_meta_data.schema.empty()) { + return; + } + // populate root schema object + file_meta_data.schema.resize(1); + file_meta_data.schema[0].name = "duckdb_schema"; + file_meta_data.schema[0].num_children = NumericCast(sql_types.size()); + file_meta_data.schema[0].__isset.num_children = true; + file_meta_data.schema[0].repetition_type = duckdb_parquet::FieldRepetitionType::REQUIRED; + file_meta_data.schema[0].__isset.repetition_type = true; + + for (auto &column_writer : column_writers) { + column_writer->FinalizeSchema(file_meta_data.schema); + } +} + +void ParquetWriter::PrepareRowGroup(ColumnDataCollection &raw_buffer, PreparedRowGroup &result, + unique_ptr &transform_data) { + AnalyzeSchema(raw_buffer, column_writers); + + bool requires_transform = false; + for (auto &writer_p : column_writers) { + auto &writer = *writer_p; + + if (writer.HasTransform()) { + requires_transform = true; + break; + } + } + + reference buffer_ref(raw_buffer); + if (requires_transform) { + InitializePreprocessing(transform_data); + buffer_ref = transform_data->ApplyTransform(raw_buffer); + } + auto &buffer = buffer_ref.get(); + // We write 8 columns at a time so that iterating over ColumnDataCollection is more efficient static constexpr idx_t COLUMNS_PER_PASS = 8; @@ -445,6 +551,8 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro row_group.num_rows = NumericCast(buffer.Count()); row_group.__isset.file_offset = true; + InitializeSchemaElements(); + auto &states = result.states; // iterate over each of the columns of the chunk collection and write them D_ASSERT(buffer.ColumnCount() == column_writers.size()); @@ -459,7 +567,7 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro write_states.emplace_back(col_writers.back().get().InitializeWriteState(row_group)); } - for (auto &chunk : buffer.Chunks({column_ids})) { + for (auto &chunk : buffer.Chunks(column_ids)) { for (idx_t i = 0; i < next; i++) { if (col_writers[i].get().HasAnalyze()) { col_writers[i].get().Analyze(*write_states[i], nullptr, chunk.data[i], chunk.size()); @@ -556,7 +664,7 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { row_group.__isset.total_compressed_size = true; if (encryption_config) { - auto row_group_ordinal = num_row_groups.load(); + const auto row_group_ordinal = file_meta_data.row_groups.size(); if (row_group_ordinal > std::numeric_limits::max()) { throw InvalidInputException("RowGroup ordinal exceeds 32767 when encryption enabled"); } @@ -572,13 +680,21 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { ++num_row_groups; } -void ParquetWriter::Flush(ColumnDataCollection &buffer) { +void ParquetWriter::Flush(ColumnDataCollection &buffer, unique_ptr &transform_data) { if (buffer.Count() == 0) { return; } + // "total_written" is only used for the FILE_SIZE_BYTES flag, and only when threads are writing in parallel. + // We pre-emptively increase it here to try to reduce overshooting when many threads are writing in parallel. + // However, waiting for the exact value (PrepareRowGroup) takes too long, and would cause overshoots to happen. + // So, we guess the compression ratio. We guess 3x, but this will be off depending on the data. + // "total_written" is restored to the exact number of written bytes at the end of FlushRowGroup. + // PhysicalCopyToFile should be reworked to use prepare/flush batch separately for better accuracy. 
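+	// Illustrative arithmetic for the guess below (numbers are made up): with a 120 MiB buffered
+	// row group, SizeInBytes() / 2 bumps total_written by roughly 60 MiB immediately, so concurrent
+	// writers consulting FILE_SIZE_BYTES already account for this pending row group; FlushRowGroup
+	// later replaces the estimate with the exact number of bytes actually written.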
+ total_written += buffer.SizeInBytes() / 2; + PreparedRowGroup prepared_row_group; - PrepareRowGroup(buffer, prepared_row_group); + PrepareRowGroup(buffer, prepared_row_group, transform_data); buffer.Reset(); FlushRowGroup(prepared_row_group); @@ -685,15 +801,13 @@ struct BlobStatsUnifier : public BaseStringStatsUnifier { }; struct GeoStatsUnifier : public ColumnStatsUnifier { - - void UnifyGeoStats(const GeometryStats &other) override { + void UnifyGeoStats(const GeometryStatsData &other) override { if (geo_stats) { - geo_stats->bbox.Combine(other.bbox); - geo_stats->types.Combine(other.types); + geo_stats->Merge(other); } else { // Make copy - geo_stats = make_uniq(); - geo_stats->bbox = other.bbox; + geo_stats = make_uniq(); + geo_stats->extent = other.extent; geo_stats->types = other.types; } } @@ -707,17 +821,17 @@ struct GeoStatsUnifier : public ColumnStatsUnifier { return string(); } - const auto &bbox = geo_stats->bbox; + const auto &bbox = geo_stats->extent; const auto &types = geo_stats->types; - const auto bbox_value = Value::STRUCT({{"xmin", bbox.xmin}, - {"xmax", bbox.xmax}, - {"ymin", bbox.ymin}, - {"ymax", bbox.ymax}, - {"zmin", bbox.zmin}, - {"zmax", bbox.zmax}, - {"mmin", bbox.mmin}, - {"mmax", bbox.mmax}}); + const auto bbox_value = Value::STRUCT({{"xmin", bbox.x_min}, + {"xmax", bbox.x_max}, + {"ymin", bbox.y_min}, + {"ymax", bbox.y_max}, + {"zmin", bbox.z_min}, + {"zmax", bbox.z_max}, + {"mmin", bbox.m_min}, + {"mmax", bbox.m_max}}); vector type_strings; for (const auto &type : types.ToString(true)) { @@ -810,11 +924,9 @@ static unique_ptr GetBaseStatsUnifier(const LogicalType &typ } } case LogicalTypeId::BLOB: - if (type.GetAlias() == "WKB_BLOB") { - return make_uniq(); - } else { - return make_uniq(); - } + return make_uniq(); + case LogicalTypeId::GEOMETRY: + return make_uniq(); case LogicalTypeId::VARCHAR: return make_uniq(); case LogicalTypeId::UUID: @@ -826,20 +938,25 @@ static unique_ptr GetBaseStatsUnifier(const LogicalType &typ } } -static void GetStatsUnifier(const ParquetColumnSchema &schema, vector> &unifiers, +static void GetStatsUnifier(const ColumnWriter &column_writer, vector> &unifiers, string base_name = string()) { - if (!base_name.empty()) { - base_name += "."; + auto &schema = column_writer.Schema(); + if (schema.repetition_type != duckdb_parquet::FieldRepetitionType::REPEATED) { + if (!base_name.empty()) { + base_name += "."; + } + base_name += KeywordHelper::WriteQuoted(schema.name, '\"'); } - base_name += KeywordHelper::WriteQuoted(schema.name, '\"'); - if (schema.children.empty()) { + + auto &children = column_writer.ChildWriters(); + if (children.empty()) { auto unifier = GetBaseStatsUnifier(schema.type); unifier->column_name = std::move(base_name); unifiers.push_back(std::move(unifier)); return; } - for (auto &child_schema : schema.children) { - GetStatsUnifier(child_schema, unifiers, base_name); + for (auto &child_writer : children) { + GetStatsUnifier(*child_writer, unifiers, base_name); } } @@ -903,22 +1020,24 @@ void ParquetWriter::GatherWrittenStatistics() { column_stats["has_nan"] = Value::BOOLEAN(stats_unifier->has_nan); } if (stats_unifier->geo_stats) { - const auto &bbox = stats_unifier->geo_stats->bbox; + const auto &bbox = stats_unifier->geo_stats->extent; const auto &types = stats_unifier->geo_stats->types; - column_stats["bbox_xmin"] = Value::DOUBLE(bbox.xmin); - column_stats["bbox_xmax"] = Value::DOUBLE(bbox.xmax); - column_stats["bbox_ymin"] = Value::DOUBLE(bbox.ymin); - column_stats["bbox_ymax"] = Value::DOUBLE(bbox.ymax); + 
if (bbox.HasXY()) { + column_stats["bbox_xmin"] = Value::DOUBLE(bbox.x_min); + column_stats["bbox_xmax"] = Value::DOUBLE(bbox.x_max); + column_stats["bbox_ymin"] = Value::DOUBLE(bbox.y_min); + column_stats["bbox_ymax"] = Value::DOUBLE(bbox.y_max); - if (bbox.HasZ()) { - column_stats["bbox_zmin"] = Value::DOUBLE(bbox.zmin); - column_stats["bbox_zmax"] = Value::DOUBLE(bbox.zmax); - } + if (bbox.HasZ()) { + column_stats["bbox_zmin"] = Value::DOUBLE(bbox.z_min); + column_stats["bbox_zmax"] = Value::DOUBLE(bbox.z_max); + } - if (bbox.HasM()) { - column_stats["bbox_mmin"] = Value::DOUBLE(bbox.mmin); - column_stats["bbox_mmax"] = Value::DOUBLE(bbox.mmax); + if (bbox.HasM()) { + column_stats["bbox_mmin"] = Value::DOUBLE(bbox.m_min); + column_stats["bbox_mmax"] = Value::DOUBLE(bbox.m_max); + } } if (!types.IsEmpty()) { @@ -934,6 +1053,7 @@ void ParquetWriter::GatherWrittenStatistics() { } void ParquetWriter::Finalize() { + InitializeSchemaElements(); // dump the bloom filters right before footer, not if stuff is encrypted @@ -975,7 +1095,8 @@ void ParquetWriter::Finalize() { } // Add geoparquet metadata to the file metadata - if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { + if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context) && + geoparquet_version != GeoParquetVersion::NONE) { geoparquet_data->Write(file_meta_data); } @@ -1005,7 +1126,7 @@ void ParquetWriter::Finalize() { GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { if (!geoparquet_data) { - geoparquet_data = make_uniq(); + geoparquet_data = make_uniq(geoparquet_version); } return *geoparquet_data; } @@ -1026,7 +1147,7 @@ void ParquetWriter::SetWrittenStatistics(CopyFunctionFileStatistics &written_sta stats_accumulator = make_uniq(); // create the per-column stats unifiers for (auto &column_writer : column_writers) { - GetStatsUnifier(column_writer->Schema(), stats_accumulator->stats_unifiers); + GetStatsUnifier(*column_writer, stats_accumulator->stats_unifiers); } } diff --git a/src/duckdb/extension/parquet/reader/list_column_reader.cpp b/src/duckdb/extension/parquet/reader/list_column_reader.cpp index 0ff1be271..b291e1019 100644 --- a/src/duckdb/extension/parquet/reader/list_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/list_column_reader.cpp @@ -175,7 +175,6 @@ ListColumnReader::ListColumnReader(ParquetReader &reader, const ParquetColumnSch unique_ptr child_column_reader_p) : ColumnReader(reader, schema), child_column_reader(std::move(child_column_reader_p)), read_cache(reader.allocator, ListType::GetChildType(Type())), read_vector(read_cache), overflow_child_count(0) { - child_defines.resize(reader.allocator, STANDARD_VECTOR_SIZE); child_repeats.resize(reader.allocator, STANDARD_VECTOR_SIZE); child_defines_ptr = (uint8_t *)child_defines.ptr; diff --git a/src/duckdb/extension/parquet/reader/string_column_reader.cpp b/src/duckdb/extension/parquet/reader/string_column_reader.cpp index 6b2a3db6d..019abd71a 100644 --- a/src/duckdb/extension/parquet/reader/string_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/string_column_reader.cpp @@ -9,7 +9,7 @@ namespace duckdb { // String Column Reader //===--------------------------------------------------------------------===// StringColumnReader::StringColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) - : ColumnReader(reader, schema) { + : ColumnReader(reader, schema), string_column_type(GetStringColumnType(Type())) { fixed_width_string_length = 0; if 
(schema.parquet_type == Type::FIXED_LEN_BYTE_ARRAY) { fixed_width_string_length = schema.type_length; @@ -26,13 +26,26 @@ void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len, co size_t pos; auto utf_type = Utf8Proc::Analyze(str_data, str_len, &reason, &pos); if (utf_type == UnicodeType::INVALID) { - throw InvalidInputException("Invalid string encoding found in Parquet file: value \"" + - Blob::ToString(string_t(str_data, str_len)) + "\" is not valid UTF8!"); + throw InvalidInputException("Invalid string encoding found in Parquet file: value \"%s\" is not valid UTF8!", + Blob::ToString(string_t(str_data, str_len))); } } -void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) { - VerifyString(str_data, str_len, Type().id() == LogicalTypeId::VARCHAR); +void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) const { + switch (string_column_type) { + case StringColumnType::VARCHAR: + VerifyString(str_data, str_len, true); + break; + case StringColumnType::JSON: { + const auto error = StringUtil::ValidateJSON(str_data, str_len); + if (!error.empty()) { + throw InvalidInputException("Invalid JSON found in Parquet file: %s", error); + } + break; + } + default: + break; + } } class ParquetStringVectorBuffer : public VectorBuffer { diff --git a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp index eacff5501..0388da0b3 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp @@ -15,7 +15,7 @@ static constexpr uint8_t VERSION_MASK = 0xF; static constexpr uint8_t SORTED_STRINGS_MASK = 0x1; static constexpr uint8_t SORTED_STRINGS_SHIFT = 4; static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_MASK = 0x3; -static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 5; +static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 6; static constexpr uint8_t BASIC_TYPE_MASK = 0x3; static constexpr uint8_t VALUE_HEADER_SHIFT = 2; @@ -74,8 +74,8 @@ VariantMetadata::VariantMetadata(const string_t &metadata) : metadata(metadata) const_data_ptr_t ptr = reinterpret_cast(metadata_data + sizeof(uint8_t)); idx_t dictionary_size = ReadVariableLengthLittleEndian(header.offset_size, ptr); - offsets = ptr; - bytes = offsets + ((dictionary_size + 1) * header.offset_size); + auto offsets = ptr; + auto bytes = offsets + ((dictionary_size + 1) * header.offset_size); idx_t last_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); for (idx_t i = 0; i < dictionary_size; i++) { auto next_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); @@ -140,8 +140,7 @@ hugeint_t DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) { return result; } -VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { switch (value_metadata.primitive_type) { case VariantPrimitiveType::NULL_TYPE: { @@ -267,8 +266,7 @@ VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &me } } -VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { 
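	// A sketch of the short-string layout assumed here (per the Variant binary encoding spec):
	// the value header byte stores the basic type in its two low bits and, for short strings,
	// the string length in the remaining six bits, which is why string_size must be < 64 (see
	// the assertion below). E.g. a header byte of 0b00010101 would encode basic type 1 (short
	// string) with a 5-byte UTF-8 payload starting at `data`.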
D_ASSERT(value_metadata.string_size < 64); auto string_data = reinterpret_cast(data); @@ -348,10 +346,10 @@ VariantValue VariantBinaryDecoder::Decode(const VariantMetadata &variant_metadat data++; switch (value_metadata.basic_type) { case VariantBasicType::PRIMITIVE: { - return PrimitiveTypeDecode(variant_metadata, value_metadata, data); + return PrimitiveTypeDecode(value_metadata, data); } case VariantBasicType::SHORT_STRING: { - return ShortStringDecode(variant_metadata, value_metadata, data); + return ShortStringDecode(value_metadata, data); } case VariantBasicType::OBJECT: { return ObjectDecode(variant_metadata, value_metadata, data); diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp index 916e6e2cd..b96304d98 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp @@ -124,7 +124,7 @@ VariantValue ConvertShreddedValue::Convert(hugeint_t val) { template vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &blob, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { UnifiedVectorFormat metadata_format; metadata.ToUnifiedFormat(length, metadata_format); auto metadata_data = metadata_format.GetData(metadata_format); @@ -174,7 +174,12 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b } else { ret[i] = OP::Convert(data[typed_index]); } - } else if (value_validity.RowIsValid(value_index)) { + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -187,7 +192,7 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { D_ASSERT(!typed_value.GetType().IsNested()); vector result; @@ -196,37 +201,37 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! boolean case LogicalTypeId::BOOLEAN: { return ConvertTypedValues, LogicalTypeId::BOOLEAN>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int8 case LogicalTypeId::TINYINT: { return ConvertTypedValues, LogicalTypeId::TINYINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int16 case LogicalTypeId::SMALLINT: { return ConvertTypedValues, LogicalTypeId::SMALLINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int32 case LogicalTypeId::INTEGER: { return ConvertTypedValues, LogicalTypeId::INTEGER>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int64 case LogicalTypeId::BIGINT: { return ConvertTypedValues, LogicalTypeId::BIGINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
float case LogicalTypeId::FLOAT: { return ConvertTypedValues, LogicalTypeId::FLOAT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! double case LogicalTypeId::DOUBLE: { return ConvertTypedValues, LogicalTypeId::DOUBLE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! decimal4/decimal8/decimal16 case LogicalTypeId::DECIMAL: { @@ -234,15 +239,15 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta switch (physical_type) { case PhysicalType::INT32: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT64: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT128: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Decimal with PhysicalType (%s) not implemented for shredded Variant", @@ -252,42 +257,42 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! date case LogicalTypeId::DATE: { return ConvertTypedValues, LogicalTypeId::DATE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! time case LogicalTypeId::TIME: { return ConvertTypedValues, LogicalTypeId::TIME>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestamptz(6) (timestamptz(9) not implemented in DuckDB) case LogicalTypeId::TIMESTAMP_TZ: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_TZ>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(6) case LogicalTypeId::TIMESTAMP: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(9) case LogicalTypeId::TIMESTAMP_NS: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_NS>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! binary case LogicalTypeId::BLOB: { return ConvertTypedValues, LogicalTypeId::BLOB>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! string case LogicalTypeId::VARCHAR: { return ConvertTypedValues, LogicalTypeId::VARCHAR>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
uuid case LogicalTypeId::UUID: { return ConvertTypedValues, LogicalTypeId::UUID>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Variant shredding on type: '%s' is not implemented", type.ToString()); @@ -395,7 +400,7 @@ static VariantValue ConvertPartiallyShreddedObject(vector vector VariantShreddedConversion::ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &type = typed_value.GetType(); D_ASSERT(type.id() == LogicalTypeId::STRUCT); auto &fields = StructType::GetChildTypes(type); @@ -445,7 +450,10 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me if (typed_validity.RowIsValid(typed_index)) { ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset); } else { - //! The value on this row is not an object, and guaranteed to be present + if (is_field && !validity.RowIsValid(value_index)) { + //! This object is a field in the parent object, the value is missing, skip it + continue; + } D_ASSERT(validity.RowIsValid(value_index)); auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); @@ -463,7 +471,7 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me vector VariantShreddedConversion::ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &child = ListVector::GetEntry(typed_value); auto list_size = ListVector::GetListSize(typed_value); @@ -489,23 +497,26 @@ vector VariantShreddedConversion::ConvertShreddedArray(Vector &met //! We can be sure that none of the values are binary encoded for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); - //! FIXME: next 4 lines duplicated below auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); } } else { for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); auto value_index = value_format.sel->get_index(i + offset); if (validity.RowIsValid(typed_index)) { - //! FIXME: next 4 lines duplicate auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); - } else if (value_validity.RowIsValid(value_index)) { + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! 
Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -547,11 +558,11 @@ vector VariantShreddedConversion::Convert(Vector &metadata, Vector auto &type = typed_value->GetType(); vector ret; if (type.id() == LogicalTypeId::STRUCT) { - return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size, is_field); } else if (type.id() == LogicalTypeId::LIST) { - return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size, is_field); } else { - return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size, is_field); } } else { if (is_field) { diff --git a/src/duckdb/extension/parquet/reader/variant/variant_value.cpp b/src/duckdb/extension/parquet/reader/variant/variant_value.cpp deleted file mode 100644 index 0ac213469..000000000 --- a/src/duckdb/extension/parquet/reader/variant/variant_value.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "reader/variant/variant_value.hpp" - -namespace duckdb { - -void VariantValue::AddChild(const string &key, VariantValue &&val) { - D_ASSERT(value_type == VariantValueType::OBJECT); - object_children.emplace(key, std::move(val)); -} - -void VariantValue::AddItem(VariantValue &&val) { - D_ASSERT(value_type == VariantValueType::ARRAY); - array_items.push_back(std::move(val)); -} - -yyjson_mut_val *VariantValue::ToJSON(ClientContext &context, yyjson_mut_doc *doc) const { - switch (value_type) { - case VariantValueType::PRIMITIVE: { - if (primitive_value.IsNull()) { - return yyjson_mut_null(doc); - } - switch (primitive_value.type().id()) { - case LogicalTypeId::BOOLEAN: { - if (primitive_value.GetValue()) { - return yyjson_mut_true(doc); - } else { - return yyjson_mut_false(doc); - } - } - case LogicalTypeId::TINYINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::SMALLINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::INTEGER: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::BIGINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::FLOAT: - return yyjson_mut_real(doc, primitive_value.GetValue()); - case LogicalTypeId::DOUBLE: - return yyjson_mut_real(doc, primitive_value.GetValue()); - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::VARCHAR: { - auto value_str = primitive_value.ToString(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP: { - auto value_str = primitive_value.ToString(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP_NS: { - auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - default: - throw 
InternalException("Unexpected primitive type: %s", primitive_value.type().ToString()); - } - } - case VariantValueType::OBJECT: { - auto obj = yyjson_mut_obj(doc); - for (const auto &it : object_children) { - auto &key = it.first; - auto value = it.second.ToJSON(context, doc); - yyjson_mut_obj_add_val(doc, obj, key.c_str(), value); - } - return obj; - } - case VariantValueType::ARRAY: { - auto arr = yyjson_mut_arr(doc); - for (auto &item : array_items) { - auto value = item.ToJSON(context, doc); - yyjson_mut_arr_add_val(arr, value); - } - return arr; - } - default: - throw InternalException("Can't serialize this VariantValue type to JSON"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp index 402bcbb07..635bfbbb5 100644 --- a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp @@ -11,7 +11,7 @@ VariantColumnReader::VariantColumnReader(ClientContext &context, ParquetReader & const ParquetColumnSchema &schema, vector> child_readers_p) : ColumnReader(reader, schema), context(context), child_readers(std::move(child_readers_p)) { - D_ASSERT(Type().InternalType() == PhysicalType::VARCHAR); + D_ASSERT(Type().InternalType() == PhysicalType::STRUCT); if (child_readers[0]->Schema().name == "metadata" && child_readers[1]->Schema().name == "value") { metadata_reader_idx = 0; @@ -80,10 +80,7 @@ idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data "The Variant column did not contain the same amount of values for 'metadata' and 'value'"); } - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - vector conversion_result; + vector intermediate; if (typed_value_reader) { auto typed_values = typed_value_reader->Read(num_values, define_out, repeat_out, *group_entries[1]); if (typed_values != value_values) { @@ -91,29 +88,9 @@ idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data "The shredded Variant column did not contain the same amount of values for 'typed_value' and 'value'"); } } - conversion_result = - VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values); - - for (idx_t i = 0; i < conversion_result.size(); i++) { - auto &variant = conversion_result[i]; - if (variant.IsNull()) { - result_validity.SetInvalid(i); - continue; - } - - //! 
Write the result to a string - VariantDecodeResult decode_result; - decode_result.doc = yyjson_mut_doc_new(nullptr); - auto json_val = variant.ToJSON(context, decode_result.doc); - - size_t len; - decode_result.data = - yyjson_mut_val_write_opts(json_val, YYJSON_WRITE_ALLOW_INF_AND_NAN, nullptr, &len, nullptr); - if (!decode_result.data) { - throw InvalidInputException("Could not serialize the JSON to string, yyjson failed"); - } - result_data[i] = StringVector::AddString(result, decode_result.data, static_cast(len)); - } + intermediate = + VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values, false); + VariantValue::ToVARIANT(intermediate, result); read_count = value_values; return read_count.GetIndex(); diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp index aa5632077..6f12d5d89 100644 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ b/src/duckdb/extension/parquet/serialize_parquet.cpp @@ -7,7 +7,8 @@ #include "duckdb/common/serializer/deserializer.hpp" #include "parquet_reader.hpp" #include "parquet_crypto.hpp" -#include "parquet_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" namespace duckdb { @@ -21,6 +22,16 @@ ChildFieldIDs ChildFieldIDs::Deserialize(Deserializer &deserializer) { return result; } +void ChildShreddingTypes::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "types", types.operator*()); +} + +ChildShreddingTypes ChildShreddingTypes::Deserialize(Deserializer &deserializer) { + ChildShreddingTypes result; + deserializer.ReadPropertyWithDefault>(100, "types", result.types.operator*()); + return result; +} + void FieldID::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(100, "set", set); serializer.WritePropertyWithDefault(101, "field_id", field_id); @@ -89,4 +100,18 @@ ParquetOptionsSerialization ParquetOptionsSerialization::Deserialize(Deserialize return result; } +void ShreddingType::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "set", set); + serializer.WriteProperty(101, "type", type); + serializer.WriteProperty(102, "children", children); +} + +ShreddingType ShreddingType::Deserialize(Deserializer &deserializer) { + ShreddingType result; + deserializer.ReadPropertyWithDefault(100, "set", result.set); + deserializer.ReadProperty(101, "type", result.type); + deserializer.ReadProperty(102, "children", result.children); + return result; +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/array_column_writer.cpp b/src/duckdb/extension/parquet/writer/array_column_writer.cpp index 60284ff28..2a9c9a9d5 100644 --- a/src/duckdb/extension/parquet/writer/array_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/array_column_writer.cpp @@ -6,7 +6,7 @@ void ArrayColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *p auto &state = state_p.Cast(); auto &array_child = ArrayVector::GetEntry(vector); auto array_size = ArrayType::GetSize(vector.GetType()); - child_writer->Analyze(*state.child_state, &state_p, array_child, array_size * count); + GetChildWriter().Analyze(*state.child_state, &state_p, array_child, array_size * count); } void ArrayColumnWriter::WriteArrayState(ListColumnWriterState &state, idx_t array_size, uint16_t first_repeat_level, @@ -35,10 +35,9 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p // write definition levels and 
repeats // the main difference between this and ListColumnWriter::Prepare is that we need to make sure to write out // repetition levels and definitions for the child elements of the array even if the array itself is NULL. - idx_t start = 0; idx_t vcount = parent ? parent->definition_levels.size() - state.parent_index : count; idx_t vector_index = 0; - for (idx_t i = start; i < vcount; i++) { + for (idx_t i = 0; i < vcount; i++) { idx_t parent_index = state.parent_index + i; if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index]) { WriteArrayState(state, array_size, parent->repetition_levels[parent_index], @@ -63,14 +62,14 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p auto &array_child = ArrayVector::GetEntry(vector); // The elements of a single array should not span multiple Parquet pages // So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size, false); + GetChildWriter().Prepare(*state.child_state, &state_p, array_child, count * array_size, false); } void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto array_size = ArrayType::GetSize(vector.GetType()); auto &array_child = ArrayVector::GetEntry(vector); - child_writer->Write(*state.child_state, array_child, count * array_size); + GetChildWriter().Write(*state.child_state, array_child, count * array_size); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp b/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp index 5994a5d27..1157e4bd6 100644 --- a/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp @@ -35,9 +35,9 @@ class BooleanWriterPageState : public ColumnWriterPageState { uint8_t byte_pos = 0; }; -BooleanColumnWriter::BooleanColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +BooleanColumnWriter::BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } unique_ptr BooleanColumnWriter::InitializeStatsState() { diff --git a/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp b/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp index 5f70697b7..4710a9fe6 100644 --- a/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp @@ -66,9 +66,9 @@ class FixedDecimalStatistics : public ColumnWriterStatistics { } }; -FixedDecimalColumnWriter::FixedDecimalColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +FixedDecimalColumnWriter::FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } unique_ptr FixedDecimalColumnWriter::InitializeStatsState() { diff --git a/src/duckdb/extension/parquet/writer/enum_column_writer.cpp 
b/src/duckdb/extension/parquet/writer/enum_column_writer.cpp index b08d2f566..3ba5d9b28 100644 --- a/src/duckdb/extension/parquet/writer/enum_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/enum_column_writer.cpp @@ -16,9 +16,9 @@ class EnumWriterPageState : public ColumnWriterPageState { bool written_value; }; -EnumColumnWriter::EnumColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path_p, bool can_have_nulls) - : PrimitiveColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { +EnumColumnWriter::EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p) + : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { bit_width = RleBpDecoder::ComputeBitWidth(EnumType::GetSize(Type())); } diff --git a/src/duckdb/extension/parquet/writer/list_column_writer.cpp b/src/duckdb/extension/parquet/writer/list_column_writer.cpp index 8fba00c23..a54017f23 100644 --- a/src/duckdb/extension/parquet/writer/list_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/list_column_writer.cpp @@ -2,25 +2,30 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; + unique_ptr ListColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { auto result = make_uniq(row_group, row_group.columns.size()); - result->child_state = child_writer->InitializeWriteState(row_group); + result->child_state = GetChildWriter().InitializeWriteState(row_group); return std::move(result); } bool ListColumnWriter::HasAnalyze() { - return child_writer->HasAnalyze(); + return GetChildWriter().HasAnalyze(); } void ListColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto &list_child = ListVector::GetEntry(vector); auto list_count = ListVector::GetListSize(vector); - child_writer->Analyze(*state.child_state, &state_p, list_child, list_count); + GetChildWriter().Analyze(*state.child_state, &state_p, list_child, list_count); } void ListColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeAnalyze(*state.child_state); + GetChildWriter().FinalizeAnalyze(*state.child_state); } static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { @@ -114,12 +119,12 @@ void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *pa auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); // The elements of a single list should not span multiple Parquet pages // So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, child_list, child_length, false); + GetChildWriter().Prepare(*state.child_state, &state_p, child_list, child_length, false); } void ListColumnWriter::BeginWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->BeginWrite(*state.child_state); + GetChildWriter().BeginWrite(*state.child_state); } void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { @@ -128,12 +133,63 @@ void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t c auto &list_child = ListVector::GetEntry(vector); Vector child_list(list_child); auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); - 
child_writer->Write(*state.child_state, child_list, child_length); + GetChildWriter().Write(*state.child_state, child_list, child_length); } void ListColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeWrite(*state.child_state); + GetChildWriter().FinalizeWrite(*state.child_state); +} + +ColumnWriter &ListColumnWriter::GetChildWriter() { + D_ASSERT(child_writers.size() == 1); + return *child_writers[0]; +} + +void ListColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto null_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + auto &type = schema.type; + + // set up the two schema elements for the list + // for some reason we only set the converted type in the OPTIONAL element + // first an OPTIONAL element + duckdb_parquet::SchemaElement optional_element; + optional_element.repetition_type = null_type; + optional_element.num_children = 1; + optional_element.converted_type = (type.id() == LogicalTypeId::MAP) ? ConvertedType::MAP : ConvertedType::LIST; + optional_element.__isset.num_children = true; + optional_element.__isset.type = false; + optional_element.__isset.repetition_type = true; + optional_element.__isset.converted_type = true; + optional_element.name = name; + if (field_id.IsValid()) { + optional_element.__isset.field_id = true; + optional_element.field_id = field_id.GetIndex(); + } + schemas.push_back(std::move(optional_element)); + + if (type.id() != LogicalTypeId::MAP) { + duckdb_parquet::SchemaElement repeated_element; + repeated_element.repetition_type = FieldRepetitionType::REPEATED; + repeated_element.__isset.num_children = true; + repeated_element.__isset.type = false; + repeated_element.__isset.repetition_type = true; + repeated_element.num_children = 1; + repeated_element.name = "list"; + schemas.push_back(std::move(repeated_element)); + } else { + //! When we're describing a MAP, we skip the dummy "list" element + //! Instead, the "key_value" struct will be marked as REPEATED + D_ASSERT(GetChildWriter().Schema().repetition_type == FieldRepetitionType::REPEATED); + } + GetChildWriter().FinalizeSchema(schemas); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index d3ebd7dfc..2cd8921ff 100644 --- a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -7,9 +7,12 @@ namespace duckdb { using duckdb_parquet::Encoding; using duckdb_parquet::PageType; -PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, - vector schema_path, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path), can_have_nulls) { +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_PAGE_SIZE; +constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; + +PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path) + : ColumnWriter(writer, std::move(column_schema), std::move(schema_path)) { } unique_ptr PrimitiveColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { @@ -44,7 +47,7 @@ void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterStat idx_t vcount = parent ? 
parent->definition_levels.size() - state.definition_levels.size() : count; idx_t parent_index = state.definition_levels.size(); auto &validity = FlatVector::Validity(vector); - HandleRepeatLevels(state, parent, count, MaxRepeat()); + HandleRepeatLevels(state, parent, count); HandleDefineLevels(state, parent, validity, count, MaxDefine(), MaxDefine() - 1); idx_t vector_index = 0; @@ -111,7 +114,7 @@ void PrimitiveColumnWriter::BeginWrite(ColumnWriterState &state_p) { hdr.type = PageType::DATA_PAGE; hdr.__isset.data_page_header = true; - hdr.data_page_header.num_values = UnsafeNumericCast(page_info.row_count); + hdr.data_page_header.num_values = NumericCast(page_info.row_count); hdr.data_page_header.encoding = GetEncoding(state); hdr.data_page_header.definition_level_encoding = Encoding::RLE; hdr.data_page_header.repetition_level_encoding = Encoding::RLE; @@ -304,12 +307,23 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta } if (state.stats_state->HasGeoStats()) { - column_chunk.meta_data.__isset.geospatial_statistics = true; - state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + auto gpq_version = writer.GetGeoParquetVersion(); + + const auto has_real_stats = gpq_version == GeoParquetVersion::NONE || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + const auto has_json_stats = gpq_version == GeoParquetVersion::V1 || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; - // Add the geospatial statistics to the extra GeoParquet metadata - writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, - *state.stats_state->GetGeoStats()); + if (has_real_stats) { + // Write the parquet native geospatial statistics + column_chunk.meta_data.__isset.geospatial_statistics = true; + state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + } + if (has_json_stats) { + // Add the geospatial statistics to the extra GeoParquet metadata + writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, + *state.stats_state->GetGeoStats()); + } } for (const auto &write_info : state.write_info) { @@ -417,4 +431,33 @@ void PrimitiveColumnWriter::WriteDictionary(PrimitiveColumnWriterState &state, u state.write_info.insert(state.write_info.begin(), std::move(write_info)); } +void PrimitiveColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + auto &type = schema.type; + auto allow_geometry = schema.allow_geometry; + + duckdb_parquet::SchemaElement schema_element; + schema_element.type = ParquetWriter::DuckDBTypeToParquetType(type); + schema_element.repetition_type = repetition_type; + schema_element.__isset.num_children = false; + schema_element.__isset.type = true; + schema_element.__isset.repetition_type = true; + schema_element.name = name; + if (field_id.IsValid()) { + schema_element.__isset.field_id = true; + schema_element.field_id = field_id.GetIndex(); + } + ParquetWriter::SetSchemaProperties(type, schema_element, allow_geometry); + schemas.push_back(std::move(schema_element)); + + D_ASSERT(child_writers.empty()); +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp index 
e65515ad5..a792b736b 100644 --- a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp @@ -2,6 +2,11 @@ namespace duckdb { +using namespace duckdb_parquet; // NOLINT + +using duckdb_parquet::ConvertedType; +using duckdb_parquet::FieldRepetitionType; + class StructColumnWriterState : public ColumnWriterState { public: StructColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) @@ -67,7 +72,7 @@ void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState * parent->is_empty.end()); } } - HandleRepeatLevels(state_p, parent, count, MaxRepeat()); + HandleRepeatLevels(state_p, parent, count); HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1); auto &child_vectors = StructVector::GetEntries(vector); for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { @@ -100,4 +105,33 @@ void StructColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { } } +void StructColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = column_schema; + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + auto &field_id = schema.field_id; + + // set up the schema element for this struct + duckdb_parquet::SchemaElement schema_element; + schema_element.repetition_type = repetition_type; + schema_element.num_children = child_writers.size(); + schema_element.__isset.num_children = true; + schema_element.__isset.type = false; + schema_element.__isset.repetition_type = true; + schema_element.name = name; + if (field_id.IsValid()) { + schema_element.__isset.field_id = true; + schema_element.field_id = field_id.GetIndex(); + } + schemas.push_back(std::move(schema_element)); + + for (auto &child_writer : child_writers) { + child_writer->FinalizeSchema(schemas); + } +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp b/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp new file mode 100644 index 000000000..c7575d09f --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp @@ -0,0 +1,202 @@ +#include "writer/variant_column_writer.hpp" +#include "parquet_writer.hpp" +#include "duckdb/common/types/decimal.hpp" + +namespace duckdb { + +unique_ptr VariantColumnWriter::AnalyzeSchemaInit() { + if (child_writers.size() == 2 && !is_analyzed) { + return make_uniq(); + } + //! 
Variant is already shredded explicitly, no need to analyze + return nullptr; +} + +static void AnalyzeSchemaInternal(VariantAnalyzeData &state, UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_index) { + if (!variant.RowIsValid(row)) { + state.type_map[static_cast(VariantLogicalType::VARIANT_NULL)]++; + return; + } + + auto type_id = variant.GetTypeId(row, values_index); + state.type_map[static_cast(type_id)]++; + + if (type_id == VariantLogicalType::OBJECT) { + if (!state.object_data) { + state.object_data = make_uniq(); + } + auto &object_data = *state.object_data; + + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto child_values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + auto child_key_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + + auto &key = variant.GetKey(row, child_key_index); + auto &child_state = object_data.fields[key.GetString()]; + AnalyzeSchemaInternal(child_state, variant, row, child_values_index); + } + } else if (type_id == VariantLogicalType::ARRAY) { + if (!state.array_data) { + state.array_data = make_uniq(); + } + auto &array_data = *state.array_data; + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto child_values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + auto &child_state = array_data.child; + AnalyzeSchemaInternal(child_state, variant, row, child_values_index); + } + } else if (type_id == VariantLogicalType::DECIMAL) { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + auto physical_type = decimal_data.GetPhysicalType(); + switch (physical_type) { + case PhysicalType::INT32: + state.decimal_type_map[0]++; + break; + case PhysicalType::INT64: + state.decimal_type_map[1]++; + break; + case PhysicalType::INT128: + state.decimal_type_map[2]++; + break; + default: + break; + } + } else if (type_id == VariantLogicalType::BOOL_FALSE) { + //! 
Move it to bool_true to have the counts all in one place + state.type_map[static_cast(VariantLogicalType::BOOL_TRUE)]++; + state.type_map[static_cast(VariantLogicalType::BOOL_FALSE)]--; + } +} + +void VariantColumnWriter::AnalyzeSchema(ParquetAnalyzeSchemaState &state_p, Vector &input, idx_t count) { + auto &state = state_p.Cast(); + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(input, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + for (idx_t i = 0; i < count; i++) { + AnalyzeSchemaInternal(state.analyze_data, variant, i, 0); + } +} + +namespace { + +struct ShredAnalysisState { + idx_t highest_count = 0; + LogicalTypeId type_id; + PhysicalType decimal_type; +}; + +} // namespace + +template +static void CheckPrimitive(const VariantAnalyzeData &state, ShredAnalysisState &result) { + auto count = state.type_map[static_cast(VARIANT_TYPE)]; + if (VARIANT_TYPE == VariantLogicalType::DECIMAL) { + if (!count) { + return; + } + auto int32_count = state.decimal_type_map[0]; + if (int32_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT32; + } + auto int64_count = state.decimal_type_map[1]; + if (int64_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT64; + } + auto int128_count = state.decimal_type_map[2]; + if (int128_count > result.highest_count) { + result.type_id = LogicalTypeId::DECIMAL; + result.decimal_type = PhysicalType::INT128; + } + } else { + if (count > result.highest_count) { + result.highest_count = count; + result.type_id = SHREDDED_TYPE; + } + } +} + +static LogicalType ConstructShreddedType(const VariantAnalyzeData &state) { + ShredAnalysisState result; + + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + //! FIXME: It's not enough for decimals to have the same PhysicalType, their width+scale has to match in order to + //! shred on the type. + // CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + CheckPrimitive(state, result); + + auto array_count = state.type_map[static_cast(VariantLogicalType::ARRAY)]; + auto object_count = state.type_map[static_cast(VariantLogicalType::OBJECT)]; + if (array_count > object_count) { + if (array_count > result.highest_count) { + auto &array_data = *state.array_data; + return LogicalType::LIST(ConstructShreddedType(array_data.child)); + } + } else { + if (object_count > result.highest_count) { + auto &object_data = *state.object_data; + + //! TODO: implement some logic to determine which fields are worth shredding, considering the overhead when + //! only 10% of rows make use of the field + child_list_t field_types; + for (auto &field : object_data.fields) { + field_types.emplace_back(field.first, ConstructShreddedType(field.second)); + } + return LogicalType::STRUCT(field_types); + } + } + + if (result.type_id == LogicalTypeId::DECIMAL) { + //! TODO: what should the scale be??? 
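+	//! For context, a sketch of what the placeholder below yields, using DuckDB's maximum width
+	//! per storage type: INT32 -> DECIMAL(9, 0), INT64 -> DECIMAL(18, 0), INT128 -> DECIMAL(38, 0).
+	//! Because the scale is pinned to 0, a value like 12.34 would not match the shredded
+	//! typed_value type and would presumably have to stay binary-encoded in the value column
+	//! (see the FIXME above about matching width and scale).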
+ if (result.decimal_type == PhysicalType::INT32) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } else if (result.decimal_type == PhysicalType::INT64) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } else if (result.decimal_type == PhysicalType::INT128) { + return LogicalType::DECIMAL(DecimalWidth::max, 0); + } + } + return result.type_id; +} + +void VariantColumnWriter::AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState &state_p) { + auto &state = state_p.Cast(); + auto shredded_type = ConstructShreddedType(state.analyze_data); + + auto typed_value = TransformTypedValueRecursive(shredded_type); + is_analyzed = true; + + auto &schema = Schema(); + auto &context = writer.GetContext(); + D_ASSERT(child_writers.size() == 2); + child_writers.pop_back(); + //! Recreate the column writer for 'value' because this is now "optional" + child_writers.push_back(ColumnWriter::CreateWriterRecursive(context, writer, schema_path, LogicalType::BLOB, + "value", false, nullptr, nullptr, schema.max_repeat, + schema.max_define + 1, true)); + child_writers.push_back(ColumnWriter::CreateWriterRecursive(context, writer, schema_path, typed_value, + "typed_value", false, nullptr, nullptr, + schema.max_repeat, schema.max_define + 1, true)); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp new file mode 100644 index 000000000..f7be8c755 --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp @@ -0,0 +1,925 @@ +#include "writer/variant_column_writer.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "reader/variant/variant_binary_decoder.hpp" +#include "parquet_shredding.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" + +namespace duckdb { + +static idx_t CalculateByteLength(idx_t value) { + if (value == 0) { + return 1; + } + auto value_data = reinterpret_cast(&value); + idx_t irrelevant_bytes = 0; + //! Check how many of the most significant bytes are 0 + for (idx_t i = sizeof(idx_t); i > 0 && value_data[i - 1] == 0; i--) { + irrelevant_bytes++; + } + return sizeof(idx_t) - irrelevant_bytes; +} + +static uint8_t EncodeMetadataHeader(idx_t byte_length) { + D_ASSERT(byte_length <= 4); + + uint8_t header_byte = 0; + //! Set 'version' to 1 + header_byte |= static_cast(1); + //! Set 'sorted_strings' to 1 + header_byte |= static_cast(1) << 4; + //! Set 'offset_size_minus_one' to byte_length-1 + header_byte |= (static_cast(byte_length) - 1) << 6; + +#ifdef DEBUG + auto decoded_header = VariantMetadataHeader::FromHeaderByte(header_byte); + D_ASSERT(decoded_header.offset_size == byte_length); +#endif + + return header_byte; +} + +static void CreateMetadata(UnifiedVariantVectorData &variant, Vector &metadata, idx_t count) { + auto &keys = variant.keys; + auto keys_data = variant.keys_data; + + //! NOTE: the parquet variant is limited to a max dictionary size of NumericLimits::Maximum() + //! 
Whereas we can have NumericLimits::Maximum() *per* string in DuckDB + auto metadata_data = FlatVector::GetData(metadata); + for (idx_t row = 0; row < count; row++) { + uint64_t dictionary_count = 0; + if (variant.RowIsValid(row)) { + auto list_entry = keys_data[keys.sel->get_index(row)]; + dictionary_count = list_entry.length; + } + idx_t dictionary_size = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + auto &key = variant.GetKey(row, i); + dictionary_size += key.GetSize(); + } + if (dictionary_size >= NumericLimits::Maximum()) { + throw InvalidInputException("The total length of the dictionary exceeds a 4 byte value (uint32_t), failed " + "to export VARIANT to Parquet"); + } + + auto byte_length = CalculateByteLength(dictionary_size); + auto total_length = 1 + (byte_length * (dictionary_count + 2)) + dictionary_size; + + metadata_data[row] = StringVector::EmptyString(metadata, total_length); + auto &metadata_blob = metadata_data[row]; + auto metadata_blob_data = metadata_blob.GetDataWriteable(); + + metadata_blob_data[0] = EncodeMetadataHeader(byte_length); + memcpy(metadata_blob_data + 1, reinterpret_cast(&dictionary_count), byte_length); + + auto offset_ptr = metadata_blob_data + 1 + byte_length; + auto string_ptr = metadata_blob_data + 1 + byte_length + ((dictionary_count + 1) * byte_length); + idx_t total_offset = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + memcpy(offset_ptr + (i * byte_length), reinterpret_cast(&total_offset), byte_length); + auto &key = variant.GetKey(row, i); + + memcpy(string_ptr + total_offset, key.GetData(), key.GetSize()); + total_offset += key.GetSize(); + } + memcpy(offset_ptr + (dictionary_count * byte_length), reinterpret_cast(&total_offset), byte_length); + D_ASSERT(offset_ptr + ((dictionary_count + 1) * byte_length) == string_ptr); + D_ASSERT(string_ptr + total_offset == metadata_blob_data + total_length); + metadata_blob.SetSizeAndFinalize(total_length, total_length); + +#ifdef DEBUG + auto decoded_metadata = VariantMetadata(metadata_blob); + D_ASSERT(decoded_metadata.strings.size() == dictionary_count); + for (idx_t i = 0; i < dictionary_count; i++) { + D_ASSERT(decoded_metadata.strings[i] == variant.GetKey(row, i).GetString()); + } +#endif + } +} + +namespace { + +static unordered_set GetVariantType(const LogicalType &type) { + if (type.id() == LogicalTypeId::ANY) { + return {}; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: + return {VariantLogicalType::OBJECT}; + case LogicalTypeId::LIST: + return {VariantLogicalType::ARRAY}; + case LogicalTypeId::BOOLEAN: + return {VariantLogicalType::BOOL_TRUE, VariantLogicalType::BOOL_FALSE}; + case LogicalTypeId::TINYINT: + return {VariantLogicalType::INT8}; + case LogicalTypeId::SMALLINT: + return {VariantLogicalType::INT16}; + case LogicalTypeId::INTEGER: + return {VariantLogicalType::INT32}; + case LogicalTypeId::BIGINT: + return {VariantLogicalType::INT64}; + case LogicalTypeId::FLOAT: + return {VariantLogicalType::FLOAT}; + case LogicalTypeId::DOUBLE: + return {VariantLogicalType::DOUBLE}; + case LogicalTypeId::DECIMAL: + return {VariantLogicalType::DECIMAL}; + case LogicalTypeId::DATE: + return {VariantLogicalType::DATE}; + case LogicalTypeId::TIME: + return {VariantLogicalType::TIME_MICROS}; + case LogicalTypeId::TIMESTAMP_TZ: + return {VariantLogicalType::TIMESTAMP_MICROS_TZ}; + case LogicalTypeId::TIMESTAMP: + return {VariantLogicalType::TIMESTAMP_MICROS}; + case LogicalTypeId::TIMESTAMP_NS: + return {VariantLogicalType::TIMESTAMP_NANOS}; + case LogicalTypeId::BLOB: + return 
{VariantLogicalType::BLOB}; + case LogicalTypeId::VARCHAR: + return {VariantLogicalType::VARCHAR}; + case LogicalTypeId::UUID: + return {VariantLogicalType::UUID}; + default: + throw BinderException("Type '%s' can't be translated to a VARIANT type", type.ToString()); + } +} + +struct ParquetVariantShreddingState : public VariantShreddingState { +public: + ParquetVariantShreddingState(const LogicalType &type, idx_t total_count) + : VariantShreddingState(type, total_count), variant_types(GetVariantType(type)) { + } + +public: + const unordered_set &GetVariantTypes() override { + return variant_types; + } + +private: + unordered_set variant_types; +}; + +struct ParquetVariantShredding : public VariantShredding { + void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) override; +}; + +} // namespace + +vector GetChildIndices(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + optional_ptr shredding_state) { + vector child_indices; + if (!shredding_state || shredding_state->type.id() != LogicalTypeId::STRUCT) { + for (idx_t i = 0; i < nested_data.child_count; i++) { + child_indices.push_back(i); + } + return child_indices; + } + //! FIXME: The variant spec says that field names should be case-sensitive, not insensitive + case_insensitive_string_set_t shredded_fields = shredding_state->ObjectFields(); + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto &key = variant.GetKey(row, keys_index); + + if (shredded_fields.count(key)) { + //! This field is shredded on, omit it from the value + continue; + } + child_indices.push_back(i); + } + return child_indices; +} + +static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + vector &offsets, optional_ptr shredding_state) { + idx_t total_size = 0; + //! Every value has at least a value header + total_size++; + + idx_t offset_size = offsets.size(); + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + switch (type_id) { + case VariantLogicalType::OBJECT: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! Calculate value and key offsets for all children + idx_t total_offset = 0; + uint32_t highest_keys_index = 0; + + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + //! All fields of the object are shredded, omit the object entirely + return 0; + } + + auto num_elements = child_indices.size(); + offsets.resize(offset_size + num_elements + 1); + + for (idx_t entry = 0; entry < child_indices.size(); entry++) { + auto i = child_indices[entry]; + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + entry] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + offsets[offset_size + num_elements] = total_offset; + + //! 
Calculate the sizes for the objects value data + auto field_id_size = CalculateByteLength(highest_keys_index); + auto field_offset_size = CalculateByteLength(total_offset); + const bool is_large = num_elements > NumericLimits::Maximum(); + + //! Now add the sizes for the objects value data + if (is_large) { + total_size += sizeof(uint32_t); + } else { + total_size += sizeof(uint8_t); + } + total_size += num_elements * field_id_size; + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::ARRAY: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + idx_t total_offset = 0; + offsets.resize(offset_size + nested_data.child_count + 1); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + i] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); + } + offsets[offset_size + nested_data.child_count] = total_offset; + + auto field_offset_size = CalculateByteLength(total_offset); + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + if (is_large) { + total_size += sizeof(uint32_t); + } else { + total_size += sizeof(uint8_t); + } + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + total_size += string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_value.GetSize() > 64) { + //! Save as regular string value + total_size += sizeof(uint32_t); + } + break; + } + case VariantLogicalType::VARIANT_NULL: + case VariantLogicalType::BOOL_TRUE: + case VariantLogicalType::BOOL_FALSE: + break; + case VariantLogicalType::INT8: + total_size += sizeof(uint8_t); + break; + case VariantLogicalType::INT16: + total_size += sizeof(uint16_t); + break; + case VariantLogicalType::INT32: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::INT64: + total_size += sizeof(uint64_t); + break; + case VariantLogicalType::FLOAT: + total_size += sizeof(float); + break; + case VariantLogicalType::DOUBLE: + total_size += sizeof(double); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + total_size += 1; + if (decimal_data.width <= 9) { + total_size += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + total_size += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + total_size += sizeof(uhugeint_t); + } else { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } + break; + } + case VariantLogicalType::UUID: + total_size += sizeof(uhugeint_t); + break; + case VariantLogicalType::DATE: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::TIME_MICROS: + case VariantLogicalType::TIMESTAMP_MICROS: + case VariantLogicalType::TIMESTAMP_NANOS: + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + total_size += sizeof(uint64_t); + break; + case VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case 
VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } + + return total_size; +} + +template +void WritePrimitiveTypeHeader(data_ptr_t &value_data) { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::PRIMITIVE); + value_header |= static_cast(TYPE_ID) << 2; + + *value_data = value_header; + value_data++; +} + +template +void CopySimplePrimitiveData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, + uint32_t values_index) { + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + memcpy(value_data, ptr, sizeof(T)); + value_data += sizeof(T); +} + +void CopyUUIDData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, uint32_t values_index) { + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto uuid = Load(ptr); + BaseUUID::ToBlob(uuid, value_data); + value_data += sizeof(uhugeint_t); +} + +static void WritePrimitiveValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + + D_ASSERT(type_id != VariantLogicalType::OBJECT && type_id != VariantLogicalType::ARRAY); + switch (type_id) { + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + auto string_size = string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_size > 64) { + if (type_id == VariantLogicalType::BLOB) { + WritePrimitiveTypeHeader(value_data); + } else { + WritePrimitiveTypeHeader(value_data); + } + Store(string_size, value_data); + value_data += sizeof(uint32_t); + } else { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::SHORT_STRING); + value_header |= static_cast(string_size) << 2; + + *value_data = value_header; + value_data++; + } + memcpy(value_data, reinterpret_cast(string_value.GetData()), string_size); + value_data += string_size; + break; + } + case VariantLogicalType::VARIANT_NULL: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_TRUE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_FALSE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::INT8: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT16: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT32: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT64: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case 
VariantLogicalType::FLOAT: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DOUBLE: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::UUID: + WritePrimitiveTypeHeader(value_data); + CopyUUIDData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DATE: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIME_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_NANOS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + + if (decimal_data.width <= 4 || decimal_data.width > 38) { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } else if (decimal_data.width <= 9) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int32_t)); + value_data += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int64_t)); + value_data += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(hugeint_t)); + value_data += sizeof(hugeint_t); + } else { + throw InternalException( + "Uncovered VARIANT(DECIMAL) -> Parquet VARIANT conversion for type 'DECIMAL(%d, %d)'", + decimal_data.width, decimal_data.scale); + } + break; + } + case VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } +} + +static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index, + optional_ptr shredding_state) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + if (type_id == VariantLogicalType::OBJECT) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! 
-- Object value header -- + + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + throw InternalException( + "The entire should be omitted, should have been handled by the Analyze step already"); + } + auto num_elements = child_indices.size(); + + //! Determine the 'field_id_size' + uint32_t highest_keys_index = 0; + for (auto &i : child_indices) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + auto field_id_size = CalculateByteLength(highest_keys_index); + + uint32_t last_offset = 0; + if (num_elements) { + last_offset = offsets[offset_index + num_elements]; + } + offset_index += num_elements + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::OBJECT); + value_header |= static_cast(is_large) << 6; + value_header |= (static_cast(field_id_size) - 1) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto object_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(object_value_header.basic_type == VariantBasicType::OBJECT); + D_ASSERT(object_value_header.is_large == is_large); + D_ASSERT(object_value_header.field_offset_size == field_offset_size); + D_ASSERT(object_value_header.field_id_size == field_id_size); +#endif + + *value_data = value_header; + value_data++; + + //! Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_id' entries + for (auto &i : child_indices) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + memcpy(value_data, reinterpret_cast(&keys_index), field_id_size); + value_data += field_id_size; + } + + //! Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (auto &i : child_indices) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else if (type_id == VariantLogicalType::ARRAY) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! 
-- Array value header -- + + uint32_t last_offset = 0; + if (nested_data.child_count) { + last_offset = offsets[offset_index + nested_data.child_count]; + } + offset_index += nested_data.child_count + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::ARRAY); + value_header |= static_cast(is_large) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto array_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(array_value_header.basic_type == VariantBasicType::ARRAY); + D_ASSERT(array_value_header.is_large == is_large); + D_ASSERT(array_value_header.field_offset_size == field_offset_size); +#endif + + *value_data = value_header; + value_data++; + + //! Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else { + WritePrimitiveValueData(variant, row, values_index, value_data, offsets, offset_index); + } +} + +static void CreateValues(UnifiedVariantVectorData &variant, Vector &value, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, + optional_ptr shredding_state, idx_t count) { + auto &validity = FlatVector::Validity(value); + auto value_data = FlatVector::GetData(value); + + for (idx_t i = 0; i < count; i++) { + idx_t value_index = 0; + if (value_index_sel) { + value_index = value_index_sel->get_index(i); + } + + idx_t row = i; + if (sel) { + row = sel->get_index(i); + } + + idx_t result_index = i; + if (result_sel) { + result_index = result_sel->get_index(i); + } + + bool is_shredded = false; + if (variant.RowIsValid(row) && shredding_state && shredding_state->ValueIsShredded(variant, row, value_index)) { + shredding_state->SetShredded(row, value_index, result_index); + is_shredded = true; + if (shredding_state->type.id() != LogicalTypeId::STRUCT) { + //! Value is shredded, directly write a NULL to the 'value' if the type is not an OBJECT + //! When the type is OBJECT, all excess fields would still need to be written to the 'value' + validity.SetInvalid(result_index); + continue; + } + } + + //! The (relative) offsets for each value, in the case of nesting + vector offsets; + //! Determine the size of this 'value' blob + idx_t blob_length = AnalyzeValueData(variant, row, value_index, offsets, shredding_state); + if (!blob_length) { + //! This is only allowed to happen for a shredded OBJECT, where there are no excess fields to write for the + //! 
OBJECT + (void)is_shredded; + D_ASSERT(is_shredded); + validity.SetInvalid(result_index); + continue; + } + value_data[result_index] = StringVector::EmptyString(value, blob_length); + auto &value_blob = value_data[result_index]; + auto value_blob_data = reinterpret_cast(value_blob.GetDataWriteable()); + + idx_t offset_index = 0; + WriteValueData(variant, row, value_index, value_blob_data, offsets, offset_index, shredding_state); + D_ASSERT(data_ptr_cast(value_blob.GetDataWriteable() + blob_length) == value_blob_data); + value_blob.SetSizeAndFinalize(blob_length, blob_length); + } +} + +void ParquetVariantShredding::WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) { + optional_ptr value; + optional_ptr typed_value; + + auto &result_type = result.GetType(); + D_ASSERT(result_type.id() == LogicalTypeId::STRUCT); + auto &child_types = StructType::GetChildTypes(result_type); + auto &child_vectors = StructVector::GetEntries(result); + D_ASSERT(child_types.size() == child_vectors.size()); + for (idx_t i = 0; i < child_types.size(); i++) { + auto &name = child_types[i].first; + if (name == "value") { + value = child_vectors[i].get(); + } else if (name == "typed_value") { + typed_value = child_vectors[i].get(); + } + } + + if (typed_value) { + ParquetVariantShreddingState shredding_state(typed_value->GetType(), count); + CreateValues(variant, *value, sel, value_index_sel, result_sel, &shredding_state, count); + + SelectionVector null_values; + if (shredding_state.count) { + WriteTypedValues(variant, *typed_value, shredding_state.shredded_sel, shredding_state.values_index_sel, + shredding_state.result_sel, shredding_state.count); + //! 'shredding_state.result_sel' will always be a subset of 'result_sel', set the rows not in the subset to + //! NULL + idx_t sel_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto original_index = result_sel ? result_sel->get_index(i) : i; + if (sel_idx < shredding_state.count && shredding_state.result_sel[sel_idx] == original_index) { + sel_idx++; + continue; + } + FlatVector::SetNull(*typed_value, original_index, true); + } + } else { + //! Set all rows of the typed_value to NULL, nothing is shredded on + for (idx_t i = 0; i < count; i++) { + FlatVector::SetNull(*typed_value, result_sel ? 
result_sel->get_index(i) : i, true); + } + } + } else { + CreateValues(variant, *value, sel, value_index_sel, result_sel, nullptr, count); + } +} + +static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &result) { + // DuckDB Variant: + // - keys = VARCHAR[] + // - children = STRUCT(keys_index UINTEGER, values_index UINTEGER)[] + // - values = STRUCT(type_id UTINYINT, byte_offset UINTEGER)[] + // - data = BLOB + + // Parquet VARIANT: + // - metadata = BLOB + // - value = BLOB + + auto &variant_vec = input.data[0]; + auto count = input.size(); + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + auto &result_vectors = StructVector::GetEntries(result); + auto &metadata = *result_vectors[0]; + CreateMetadata(variant, metadata, count); + + ParquetVariantShredding shredding; + shredding.WriteVariantValues(variant, result, nullptr, nullptr, nullptr, count); + + if (input.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +void VariantColumnWriter::FinalizeSchema(vector &schemas) { + idx_t schema_idx = schemas.size(); + + auto &schema = Schema(); + schema.SetSchemaIndex(schema_idx); + + auto &repetition_type = schema.repetition_type; + auto &name = schema.name; + + // variant group + duckdb_parquet::SchemaElement top_element; + top_element.repetition_type = repetition_type; + top_element.num_children = child_writers.size(); + top_element.logicalType.__isset.VARIANT = true; + top_element.logicalType.VARIANT.__isset.specification_version = true; + top_element.logicalType.VARIANT.specification_version = 1; + top_element.__isset.logicalType = true; + top_element.__isset.num_children = true; + top_element.__isset.repetition_type = true; + top_element.name = name; + schemas.push_back(std::move(top_element)); + + for (auto &child_writer : child_writers) { + child_writer->FinalizeSchema(schemas); + } +} + +LogicalType VariantColumnWriter::TransformTypedValueRecursive(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::STRUCT: { + //! 
Wrap all fields of the struct in a struct with 'value' and 'typed_value' fields + auto &child_types = StructType::GetChildTypes(type); + child_list_t replaced_types; + for (auto &entry : child_types) { + child_list_t child_children; + child_children.emplace_back("value", LogicalType::BLOB); + if (entry.second.id() != LogicalTypeId::VARIANT) { + child_children.emplace_back("typed_value", TransformTypedValueRecursive(entry.second)); + } + replaced_types.emplace_back(entry.first, LogicalType::STRUCT(child_children)); + } + return LogicalType::STRUCT(replaced_types); + } + case LogicalTypeId::LIST: { + auto &child_type = ListType::GetChildType(type); + child_list_t replaced_types; + replaced_types.emplace_back("value", LogicalType::BLOB); + if (child_type.id() != LogicalTypeId::VARIANT) { + replaced_types.emplace_back("typed_value", TransformTypedValueRecursive(child_type)); + } + return LogicalType::LIST(LogicalType::STRUCT(replaced_types)); + } + case LogicalTypeId::UNION: + case LogicalTypeId::MAP: + case LogicalTypeId::VARIANT: + case LogicalTypeId::ARRAY: + throw BinderException("'%s' can't appear inside the a 'typed_value' shredded type!", type.ToString()); + default: + return type; + } +} + +static LogicalType GetParquetVariantType(optional_ptr shredding = nullptr) { + child_list_t children; + children.emplace_back("metadata", LogicalType::BLOB); + children.emplace_back("value", LogicalType::BLOB); + if (shredding) { + children.emplace_back("typed_value", VariantColumnWriter::TransformTypedValueRecursive(*shredding)); + } + auto res = LogicalType::STRUCT(std::move(children)); + res.SetAlias("PARQUET_VARIANT"); + return res; +} + +static unique_ptr BindTransform(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.empty()) { + return nullptr; + } + auto type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); + + if (arguments.size() == 2) { + auto &shredding = *arguments[1]; + auto expr_return_type = ExpressionBinder::GetExpressionReturnType(shredding); + expr_return_type = LogicalType::NormalizeType(expr_return_type); + if (expr_return_type.id() != LogicalTypeId::VARCHAR) { + throw BinderException("Optional second argument 'shredding' has to be of type VARCHAR, i.e: " + "'STRUCT(my_field BOOLEAN)', found type: '%s' instead", + expr_return_type); + } + if (!shredding.IsFoldable()) { + throw BinderException("Optional second argument 'shredding' has to be a constant expression"); + } + Value type_str = ExpressionExecutor::EvaluateScalar(context, shredding); + if (type_str.IsNull()) { + throw BinderException("Optional second argument 'shredding' can not be NULL"); + } + auto shredded_type = TransformStringToLogicalType(type_str.GetValue()); + bound_function.SetReturnType(GetParquetVariantType(shredded_type)); + } else { + bound_function.SetReturnType(GetParquetVariantType()); + } + + return nullptr; +} + +ScalarFunction VariantColumnWriter::GetTransformFunction() { + ScalarFunction transform("variant_to_parquet_variant", {LogicalType::VARIANT()}, LogicalType::ANY, ToParquetVariant, + BindTransform); + transform.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return transform; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/zstd_file_system.cpp b/src/duckdb/extension/parquet/zstd_file_system.cpp index 3bddf8661..9879ff4e4 100644 --- a/src/duckdb/extension/parquet/zstd_file_system.cpp +++ b/src/duckdb/extension/parquet/zstd_file_system.cpp @@ -28,7 +28,19 @@ ZstdStreamWrapper::~ZstdStreamWrapper() { } try { 
Close(); - } catch (...) { // NOLINT: swallow exceptions in destructor + } catch (std::exception &ex) { + if (file && file->child_handle) { + // FIXME: Make any log context available here. + ErrorData data(ex); + try { + const auto logger = file->child_handle->logger; + if (logger) { + DUCKDB_LOG_ERROR(logger, "ZstdStreamWrapper::~ZstdStreamWrapper()\t\t" + data.Message()); + } + } catch (...) { // NOLINT + } + } + } catch (...) { // NOLINT } } diff --git a/src/duckdb/generated_extension_loader_package_build.cpp b/src/duckdb/generated_extension_loader_package_build.cpp new file mode 100644 index 000000000..95c8dad03 --- /dev/null +++ b/src/duckdb/generated_extension_loader_package_build.cpp @@ -0,0 +1,59 @@ +#include "core_functions_extension.hpp" +#include "parquet_extension.hpp" +#include "icu_extension.hpp" +#include "json_extension.hpp" +#include "jemalloc_extension.hpp" +#include "duckdb/main/extension/generated_extension_loader.hpp" +#include "duckdb/main/extension_helper.hpp" + +namespace duckdb { + +//! Looks through the package_build.py-generated list of extensions that are linked into DuckDB currently to try load +ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string &extension) { + + if (extension=="core_functions") { + db.LoadStaticExtension(); + return ExtensionLoadResult::LOADED_EXTENSION; + } + + if (extension=="parquet") { + db.LoadStaticExtension(); + return ExtensionLoadResult::LOADED_EXTENSION; + } + + if (extension=="icu") { + db.LoadStaticExtension(); + return ExtensionLoadResult::LOADED_EXTENSION; + } + + if (extension=="json") { + db.LoadStaticExtension(); + return ExtensionLoadResult::LOADED_EXTENSION; + } + + if (extension=="jemalloc") { + db.LoadStaticExtension(); + return ExtensionLoadResult::LOADED_EXTENSION; + } + + return ExtensionLoadResult::NOT_LOADED; +} + +vector LinkedExtensions(){ + vector VEC = {"core_functions", "parquet", "icu", "json", "jemalloc" + }; + return VEC; +} + +void ExtensionHelper::LoadAllExtensions(DuckDB &db) { + for (auto& ext_name : LinkedExtensions()) { + LoadExtension(db, ext_name); + } +} + +vector ExtensionHelper::LoadedExtensionTestPaths(){ + vector VEC = { + }; + return VEC; +} +} \ No newline at end of file diff --git a/src/duckdb/src/catalog/catalog.cpp b/src/duckdb/src/catalog/catalog.cpp index 08e27f28f..0dec15f20 100644 --- a/src/duckdb/src/catalog/catalog.cpp +++ b/src/duckdb/src/catalog/catalog.cpp @@ -909,6 +909,22 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con if (if_not_found == OnEntryNotFound::RETURN_NULL) { return {nullptr, nullptr, ErrorData()}; } + + // If we have a specific schema name and no schemas were found, the schema doesn't exist. + // Throw an error about the schema instead of the table + if (schemas.empty() && !lookups.empty() && lookup_info.GetCatalogType() == CatalogType::TABLE_ENTRY) { + string schema_name = lookups[0].schema; + if (!IsInvalidSchema(schema_name)) { + EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema_name, lookup_info.GetErrorContext()); + string relation_name = schema_name + "." + lookup_info.GetEntryName(); + auto except = + CatalogException(schema_lookup.GetErrorContext(), + "Table with name \"%s\" does not exist because schema \"%s\" does not exist.", + relation_name, schema_name); + return {nullptr, nullptr, ErrorData(except)}; + } + } + // Check if the default database is actually attached. CreateMissingEntryException will throw binder exception // otherwise. 
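For reference, a small caller-side sketch of the new lookup error, using the public C++ API; the schema and table names are made up for illustration, and the quoted message mirrors the exception text added above:

#include "duckdb.hpp"

#include <iostream>

int main() {
    duckdb::DuckDB db(nullptr); // in-memory database
    duckdb::Connection con(db);

    // "no_such_schema" was never created, so the lookup hits the new branch in
    // Catalog::TryLookupEntry and reports the missing schema instead of only the table.
    auto result = con.Query("SELECT * FROM no_such_schema.some_table");
    if (result->HasError()) {
        // Expected (per the patch): Table with name "no_such_schema.some_table" does not
        // exist because schema "no_such_schema" does not exist.
        std::cerr << result->GetError() << std::endl;
    }
    return 0;
}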
if (!GetCatalogEntry(context, GetDefaultCatalog(retriever))) { diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp index 7fdc0c3be..8fca4a954 100644 --- a/src/duckdb/src/catalog/catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry.cpp @@ -48,7 +48,7 @@ unique_ptr CatalogEntry::GetInfo() const { } string CatalogEntry::ToSQL() const { - throw InternalException("Unsupported catalog type for ToSQL()"); + throw InternalException({{"catalog_type", CatalogTypeToString(type)}}, "Unsupported catalog type for ToSQL()"); } void CatalogEntry::SetChild(unique_ptr child_p) { diff --git a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp index 25544a343..6d639bef6 100644 --- a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp @@ -3,6 +3,8 @@ namespace duckdb { +constexpr const char *CopyFunctionCatalogEntry::Name; + CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateCopyFunctionInfo &info) : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp index c70984e53..769a06b9f 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp @@ -22,7 +22,6 @@ void DuckIndexEntry::Rollback(CatalogEntry &) { DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, TableCatalogEntry &table_p) : IndexCatalogEntry(catalog, schema, create_info), initial_index_size(0) { - auto &table = table_p.Cast(); auto &storage = table.GetStorage(); info = make_shared_ptr(storage.GetDataTableInfo(), name); diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp index a0f40ce82..33d0db4da 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp @@ -391,7 +391,7 @@ CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { case CatalogType::TYPE_ENTRY: return types; default: - throw InternalException("Unsupported catalog type in schema"); + throw InternalException({{"catalog_type", CatalogTypeToString(type)}}, "Unsupported catalog type in schema"); } } diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index b80204ac0..dfecd0f16 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -29,7 +29,6 @@ namespace duckdb { IndexStorageInfo GetIndexInfo(const IndexConstraintType type, const bool v1_0_0_storage, unique_ptr &info, const idx_t id) { - auto &table_info = info->Cast(); auto constraint_name = EnumUtil::ToString(type) + "_"; auto name = constraint_name + table_info.table + "_" + to_string(id); @@ -44,7 +43,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou shared_ptr inherited_storage) : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), column_dependency_manager(std::move(info.column_dependency_manager)) { - if (storage) { if (!info.indexes.empty()) { 
storage->SetIndexStorageInfo(std::move(info.indexes)); @@ -55,9 +53,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou // create the physical storage vector column_defs; for (auto &col_def : columns.Physical()) { - if (TypeVisitor::Contains(col_def.Type(), LogicalTypeId::VARIANT)) { - throw NotImplementedException("A table cannot be created from a VARIANT column yet"); - } column_defs.push_back(col_def.Copy()); } storage = make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), @@ -68,7 +63,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou for (idx_t i = 0; i < constraints.size(); i++) { auto &constraint = constraints[i]; if (constraint->type == ConstraintType::UNIQUE) { - // UNIQUE constraint: Create a unique index. auto &unique = constraint->Cast(); IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; @@ -99,7 +93,6 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou auto &bfk = constraint->Cast(); if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - vector column_indexes; for (const auto &physical_index : bfk.info.fk_keys) { auto &col = columns.GetColumn(physical_index); @@ -595,12 +588,24 @@ void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_i auto copy = constraint->Copy(); auto &unique = copy->Cast(); if (unique.HasIndex()) { + // Single-column UNIQUE constraint if (unique.GetIndex() == removed_index) { throw CatalogException( "Cannot drop column \"%s\" because there is a UNIQUE constraint that depends on it", info.removed_column); } unique.SetIndex(adjusted_indices[unique.GetIndex().index]); + } else { + // Multi-column UNIQUE constraint - check if any column matches the one being dropped + for (const auto &col_name : unique.GetColumnNames()) { + if (col_name == info.removed_column) { + // Build constraint string for error message: UNIQUE(col1, col2, ...) + auto constraint_str = "UNIQUE(" + StringUtil::Join(unique.GetColumnNames(), ", ") + ")"; + throw CatalogException( + "Cannot drop column \"%s\" because it is referenced in unique constraint %s", + info.removed_column, constraint_str); + } + } } create_info.constraints.push_back(std::move(copy)); break; @@ -881,31 +886,20 @@ unique_ptr DuckTableEntry::RenameField(ClientContext &context, Ren } unique_ptr DuckTableEntry::SetDefault(ClientContext &context, SetDefaultInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; auto default_idx = GetColumnIndex(info.column_name); if (default_idx.index == COLUMN_IDENTIFIER_ROW_ID) { throw CatalogException("Cannot SET DEFAULT for rowid column"); } - // Copy all the columns, changing the value of the one that was specified by 'column_name' - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (default_idx == col.Logical()) { - // set the default value of this column - if (copy.Generated()) { - throw BinderException("Cannot SET DEFAULT for generated column \"%s\"", col.Name()); - } - copy.SetDefaultValue(info.expression ? 
info.expression->Copy() : nullptr); - } - create_info->columns.AddColumn(std::move(copy)); - } - // Copy all the constraints - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); + auto create_info = GetInfo(); + auto &table_info = create_info->Cast(); + + // Modify the column that was specified by 'column_name' + auto &col = table_info.columns.GetColumnMutable(default_idx); + if (col.Generated()) { + throw BinderException("Cannot SET DEFAULT for generated column \"%s\"", col.Name()); } + col.SetDefaultValue(info.expression ? info.expression->Copy() : nullptr); auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); @@ -913,29 +907,28 @@ unique_ptr DuckTableEntry::SetDefault(ClientContext &context, SetD } unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetNotNullInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - auto not_null_idx = GetColumnIndex(info.column_name); if (columns.GetColumn(LogicalIndex(not_null_idx)).Generated()) { throw BinderException("Unsupported constraint for generated column!"); } + + auto create_info = GetInfo(); + auto &table_info = create_info->Cast(); + bool has_not_null = false; - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); + for (auto &constraint : table_info.constraints) { if (constraint->type == ConstraintType::NOT_NULL) { auto ¬_null = constraint->Cast(); if (not_null.index == not_null_idx) { has_not_null = true; + break; } } - create_info->constraints.push_back(std::move(constraint)); } if (!has_not_null) { - create_info->constraints.push_back(make_uniq(not_null_idx)); + table_info.constraints.push_back(make_uniq(not_null_idx)); } + auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); @@ -952,22 +945,21 @@ unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetN } unique_ptr DuckTableEntry::DropNotNull(ClientContext &context, DropNotNullInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - auto not_null_idx = GetColumnIndex(info.column_name); - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - // Skip/drop not_null + + auto create_info = GetInfo(); + auto &table_info = create_info->Cast(); + + // Remove the NOT NULL constraint for the specified column + for (idx_t i = 0; i < table_info.constraints.size(); i++) { + auto &constraint = table_info.constraints[i]; if (constraint->type == ConstraintType::NOT_NULL) { auto ¬_null = constraint->Cast(); if (not_null.index == not_null_idx) { - continue; + table_info.constraints.erase(table_info.constraints.begin() + static_cast(i)); + break; } } - create_info->constraints.push_back(std::move(constraint)); } auto binder = Binder::CreateBinder(context); @@ -1074,27 +1066,17 @@ unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context } unique_ptr DuckTableEntry::SetColumnComment(ClientContext &context, SetColumnCommentInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - auto default_idx = GetColumnIndex(info.column_name); - if (default_idx.index == 
COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot SET DEFAULT for rowid column"); + auto col_idx = GetColumnIndex(info.column_name); + if (col_idx.index == COLUMN_IDENTIFIER_ROW_ID) { + throw CatalogException("Cannot SET COMMENT for rowid column"); } - // Copy all the columns, changing the value of the one that was specified by 'column_name' - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (default_idx == col.Logical()) { - copy.SetComment(info.comment_value); - } - create_info->columns.AddColumn(std::move(copy)); - } - // Copy all the constraints - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } + auto create_info = GetInfo(); + auto &table_info = create_info->Cast(); + + // Modify the column that was specified by 'column_name' + auto &col = table_info.columns.GetColumnMutable(col_idx); + col.SetComment(info.comment_value); auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); @@ -1199,14 +1181,8 @@ void DuckTableEntry::OnDrop() { } unique_ptr DuckTableEntry::AddConstraint(ClientContext &context, AddConstraintInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - - // Copy all columns and constraints to the modified table. - create_info->columns = columns.Copy(); - for (const auto &constraint : constraints) { - create_info->constraints.push_back(constraint->Copy()); - } + auto create_info = GetInfo(); + auto &table_info = create_info->Cast(); if (info.constraint->type == ConstraintType::UNIQUE) { const auto &unique = info.constraint->Cast(); @@ -1216,7 +1192,7 @@ unique_ptr DuckTableEntry::AddConstraint(ClientContext &context, A auto existing_name = existing_pk->ToString(); throw CatalogException("table \"%s\" can have only one primary key: %s", name, existing_name); } - create_info->constraints.push_back(info.constraint->Copy()); + table_info.constraints.push_back(info.constraint->Copy()); } else { throw InternalException("unsupported constraint type in ALTER TABLE statement"); @@ -1224,7 +1200,7 @@ unique_ptr DuckTableEntry::AddConstraint(ClientContext &context, A // We create a physical table with a new constraint and a new unique index. 
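The ALTER helpers in this hunk (SetDefault, SetNotNull, DropNotNull, SetColumnComment, AddConstraint) all move from hand-copying columns and constraints into a fresh CreateTableInfo to cloning the entry via GetInfo() and mutating the copy; the Cast target has been stripped by the angle-bracket mangling and is presumably CreateTableInfo. A hedged sketch of the shared pattern, with column_index and new_default standing in for the per-ALTER specifics (illustrative fragment, not standalone code):

// Clone the entry instead of rebuilding CreateTableInfo by hand.
auto create_info = GetInfo();                            // deep copy of the table entry
auto &table_info = create_info->Cast<CreateTableInfo>(); // Cast<> target presumed, brackets lost above

// Mutate only what this ALTER touches, e.g. a column default:
auto &col = table_info.columns.GetColumnMutable(column_index);
col.SetDefaultValue(new_default ? new_default->Copy() : nullptr);

// Rebind and build the new entry exactly as before:
auto binder = Binder::CreateBinder(context);
auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema);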
const auto binder = Binder::CreateBinder(context); - const auto bound_constraint = binder->BindConstraint(*info.constraint, create_info->table, create_info->columns); + const auto bound_constraint = binder->BindConstraint(*info.constraint, table_info.table, table_info.columns); const auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); auto new_storage = make_shared_ptr(context, *storage, *bound_constraint); @@ -1233,15 +1209,8 @@ unique_ptr DuckTableEntry::AddConstraint(ClientContext &context, A } unique_ptr DuckTableEntry::Copy(ClientContext &context) const { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } + D_ASSERT(!internal); + auto create_info = GetInfo(); auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableCheckpoint(std::move(create_info), schema); @@ -1285,8 +1254,8 @@ TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr return TableScanFunction::GetFunction(); } -vector DuckTableEntry::GetColumnSegmentInfo() { - return storage->GetColumnSegmentInfo(); +vector DuckTableEntry::GetColumnSegmentInfo(const QueryContext &context) { + return storage->GetColumnSegmentInfo(context); } TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { diff --git a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp index 2c5cb9ae7..ed71e174c 100644 --- a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -5,7 +5,6 @@ namespace duckdb { IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) : StandardEntry(CatalogType::INDEX_ENTRY, schema, catalog, info.index_name), sql(info.sql), options(info.options), index_type(info.index_type), index_constraint_type(info.constraint_type), column_ids(info.column_ids) { - this->temporary = info.temporary; this->dependencies = info.dependencies; this->comment = info.comment; diff --git a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp index ff247dcb0..9d9789192 100644 --- a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" namespace duckdb { +constexpr const char *PragmaFunctionCatalogEntry::Name; PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreatePragmaFunctionInfo &info) diff --git a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp index 49b20f677..e5778ad4c 100644 --- a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp @@ -5,6 +5,8 @@ namespace duckdb { +constexpr const char *ScalarFunctionCatalogEntry::Name; + ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateScalarFunctionInfo &info) : 
FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { diff --git a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp index 6153a8e8a..d6a548a26 100644 --- a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -13,6 +13,8 @@ namespace duckdb { +constexpr const char *SequenceCatalogEntry::Name; + SequenceData::SequenceData(CreateSequenceInfo &info) : usage_count(info.usage_count), counter(info.start_value), last_value(info.start_value), increment(info.increment), start_value(info.start_value), min_value(info.min_value), max_value(info.max_value), cycle(info.cycle) { diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 22a173fd8..efa26e5cc 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -19,6 +19,8 @@ namespace duckdb { +constexpr const char *TableCatalogEntry::Name; + TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), constraints(std::move(info.constraints)) { @@ -79,6 +81,8 @@ unique_ptr TableCatalogEntry::GetInfo() const { result->dependencies = dependencies; std::for_each(constraints.begin(), constraints.end(), [&result](const unique_ptr &c) { result->constraints.emplace_back(c->Copy()); }); + result->temporary = temporary; + result->internal = internal; result->comment = comment; result->tags = tags; return std::move(result); @@ -266,7 +270,7 @@ void LogicalUpdate::BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, } } -vector TableCatalogEntry::GetColumnSegmentInfo() { +vector TableCatalogEntry::GetColumnSegmentInfo(const QueryContext &context) { return {}; } diff --git a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp index a6a41ff61..f06ef164e 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp @@ -4,6 +4,8 @@ namespace duckdb { +constexpr const char *TableFunctionCatalogEntry::Name; + TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableFunctionInfo &info) : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { diff --git a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp index 0bb4a3f3a..324413b7c 100644 --- a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp @@ -9,6 +9,8 @@ namespace duckdb { +constexpr const char *TypeCatalogEntry::Name; + TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type), bind_function(info.bind_function) { diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp index 6af56c22d..37dd72f72 100644 --- a/src/duckdb/src/catalog/catalog_search_path.cpp +++ 
b/src/duckdb/src/catalog/catalog_search_path.cpp @@ -8,6 +8,8 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/main/database_manager.hpp" +#include "duckdb/common/exception/parser_exception.hpp" + namespace duckdb { CatalogSearchEntry::CatalogSearchEntry(string catalog_p, string schema_p) @@ -24,8 +26,8 @@ string CatalogSearchEntry::ToString() const { string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' || input[i] == ',') { - return "\"" + input + "\""; + if (input[i] == '.' || input[i] == ',' || input[i] == '"') { + return "\"" + StringUtil::Replace(input, "\"", "\"\"") + "\""; } } return input; diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index deff8daae..6e5b17610 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -352,6 +352,7 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, map.UpdateEntry(std::move(value)); // push the old entry in the undo buffer for this transaction + unique_ptr entry_to_destroy; if (transaction.transaction) { // serialize the AlterInfo into a temporary buffer MemoryStream stream(Allocator::Get(*transaction.db)); @@ -363,6 +364,10 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, DuckTransactionManager::Get(GetCatalog().GetAttached()) .PushCatalogEntry(*transaction.transaction, new_entry->Child(), stream.GetData(), stream.GetPosition()); + } else { + // if we don't have a transaction this alter is non-transactional + // in that case we are able to just directly destroy the child (if there is any) + entry_to_destroy = new_entry->TakeChild(); } read_lock.unlock(); @@ -370,7 +375,6 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, // Check the dependency manager to verify that there are no conflicting dependencies with this alter catalog.GetDependencyManager()->AlterObject(transaction, *entry, *new_entry, alter_info); - return true; } @@ -401,8 +405,6 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); } - entry->OnDrop(); - // create a new tombstone entry and replace the currently stored one // set the timestamp to the timestamp of the current transaction // and point it at the tombstone node @@ -454,6 +456,7 @@ void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEnt void CatalogSet::CommitDrop(transaction_t commit_id, transaction_t start_time, CatalogEntry &entry) { auto &duck_catalog = GetCatalog(); + entry.OnDrop(); // Make sure that we don't see any uncommitted changes auto transaction_id = MAX_TRANSACTION_ID; // This will allow us to see all committed changes made before this COMMIT happened diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp index f51038b1e..d45bddb71 100644 --- a/src/duckdb/src/catalog/default/default_functions.cpp +++ b/src/duckdb/src/catalog/default/default_functions.cpp @@ -86,6 +86,7 @@ static const DefaultMacro internal_macros[] = { {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, {{nullptr, nullptr}}, "true"}, {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, {{nullptr, nullptr}}, "format_bytes(bytes)"}, + {"pg_catalog", "pg_sleep", {"seconds", nullptr}, {{nullptr, nullptr}}, "sleep_ms(CAST(seconds * 1000 AS 
BIGINT))"}, {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, {{nullptr, nullptr}}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, {{nullptr, nullptr}}, "round_even(x, n)"}, @@ -98,9 +99,9 @@ static const DefaultMacro internal_macros[] = { {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[2:]"}, {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(arr, list_value(e))"}, {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, // Test default parameters - {DEFAULT_SCHEMA, "array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, + {DEFAULT_SCHEMA, "array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, {{nullptr, nullptr}}, "unnest(generate_series(1, array_length(arr, dim)))"}, {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, {{nullptr, nullptr}}, "floor(x/y)"}, @@ -115,10 +116,6 @@ static const DefaultMacro internal_macros[] = { {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "l[:-:-1]"}, {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "list_reverse(l)"}, - // FIXME implement as actual function if we encounter a lot of performance issues. 
Complexity now: n * m, with hashing possibly n + m - {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_filter(list_distinct(l1), lambda variable_intersect: list_contains(l2, variable_intersect))"}, - {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_intersect(l1, l2)"}, - // algebraic list aggregates {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'avg')"}, {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'var_samp')"}, diff --git a/src/duckdb/src/catalog/default/default_table_functions.cpp b/src/duckdb/src/catalog/default/default_table_functions.cpp index c07786474..94079bbcb 100644 --- a/src/duckdb/src/catalog/default/default_table_functions.cpp +++ b/src/duckdb/src/catalog/default/default_table_functions.cpp @@ -69,7 +69,7 @@ FROM histogram_values(source, col_name, bin_count := bin_count, technique := tec {DEFAULT_SCHEMA, "duckdb_logs_parsed", {"log_type"}, {}, R"( SELECT * EXCLUDE (message), UNNEST(parse_duckdb_log_message(log_type, message)) FROM duckdb_logs(denormalized_table=1) -WHERE type = log_type +WHERE type ILIKE log_type )"}, {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} }; diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp index ddd7550a7..840310f7e 100644 --- a/src/duckdb/src/catalog/dependency_manager.cpp +++ b/src/duckdb/src/catalog/dependency_manager.cpp @@ -16,6 +16,8 @@ #include "duckdb/parser/constraints/foreign_key_constraint.hpp" #include "duckdb/catalog/dependency_catalog_set.hpp" +#include "duckdb/common/printer.hpp" + namespace duckdb { static void AssertMangledName(const string &mangled_name, idx_t expected_null_bytes) { diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp index 054eaaf0f..b461a88c2 100644 --- a/src/duckdb/src/common/adbc/adbc.cpp +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -18,13 +18,22 @@ #include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/parser/keyword_helper.hpp" + // We must leak the symbols of the init function AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *error) { if (!driver) { return ADBC_STATUS_INVALID_ARGUMENT; } + + // Check that the version is supported (1.0.0 or 1.1.0) + if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + auto adbc_driver = static_cast(driver); + // Initialize all 1.0.0 function pointers adbc_driver->DatabaseNew = duckdb_adbc::DatabaseNew; adbc_driver->DatabaseSetOption = duckdb_adbc::DatabaseSetOption; adbc_driver->DatabaseInit = duckdb_adbc::DatabaseInit; @@ -50,12 +59,51 @@ AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *err adbc_driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo; adbc_driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema; adbc_driver->ConnectionGetTableSchema = duckdb_adbc::ConnectionGetTableSchema; + + // Initialize 1.1.0 function pointers if version >= 1.1.0 + if (version >= ADBC_VERSION_1_1_0) { + // TODO: ADBC 1.1.0 adds support for these functions + adbc_driver->ErrorGetDetailCount = nullptr; + adbc_driver->ErrorGetDetail = nullptr; + adbc_driver->ErrorFromArrayStream = nullptr; + + adbc_driver->DatabaseGetOption = nullptr; + adbc_driver->DatabaseGetOptionBytes = nullptr; + adbc_driver->DatabaseGetOptionDouble = nullptr; + adbc_driver->DatabaseGetOptionInt = nullptr; + 
adbc_driver->DatabaseSetOptionBytes = nullptr; + adbc_driver->DatabaseSetOptionInt = nullptr; + adbc_driver->DatabaseSetOptionDouble = nullptr; + + adbc_driver->ConnectionCancel = nullptr; + adbc_driver->ConnectionGetOption = nullptr; + adbc_driver->ConnectionGetOptionBytes = nullptr; + adbc_driver->ConnectionGetOptionDouble = nullptr; + adbc_driver->ConnectionGetOptionInt = nullptr; + adbc_driver->ConnectionGetStatistics = nullptr; + adbc_driver->ConnectionGetStatisticNames = nullptr; + adbc_driver->ConnectionSetOptionBytes = nullptr; + adbc_driver->ConnectionSetOptionInt = nullptr; + adbc_driver->ConnectionSetOptionDouble = nullptr; + + adbc_driver->StatementCancel = nullptr; + adbc_driver->StatementExecuteSchema = nullptr; + adbc_driver->StatementGetOption = nullptr; + adbc_driver->StatementGetOptionBytes = nullptr; + adbc_driver->StatementGetOptionDouble = nullptr; + adbc_driver->StatementGetOptionInt = nullptr; + adbc_driver->StatementSetOptionBytes = nullptr; + adbc_driver->StatementSetOptionDouble = nullptr; + adbc_driver->StatementSetOptionInt = nullptr; + } + return ADBC_STATUS_OK; } namespace duckdb_adbc { -enum class IngestionMode { CREATE = 0, APPEND = 1 }; +// ADBC 1.1.0: Added REPLACE and CREATE_APPEND modes +enum class IngestionMode { CREATE = 0, APPEND = 1, REPLACE = 2, CREATE_APPEND = 3 }; struct DuckDBAdbcStatementWrapper { duckdb_connection connection; @@ -196,7 +244,6 @@ AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *err } AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - if (database && database->private_data) { auto wrapper = static_cast(database->private_data); @@ -537,7 +584,8 @@ static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) auto count = duckdb_column_count(&result_wrapper->result); std::vector types(count); - std::vector owned_names(count); + std::vector owned_names; + owned_names.reserve(count); duckdb::vector names(count); for (idx_t i = 0; i < count; i++) { types[i] = duckdb_column_logical_type(&result_wrapper->result, i); @@ -548,7 +596,7 @@ static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) auto arrow_options = duckdb_result_get_arrow_options(&result_wrapper->result); - auto res = duckdb_to_arrow_schema(arrow_options, &types[0], names.data(), count, out); + auto res = duckdb_to_arrow_schema(arrow_options, types.data(), names.data(), count, out); duckdb_destroy_arrow_options(&arrow_options); for (auto &type : types) { duckdb_destroy_logical_type(&type); @@ -605,7 +653,6 @@ const char *get_last_error(struct ArrowArrayStream *stream) { duckdb::unique_ptr stream_produce(uintptr_t factory_ptr, duckdb::ArrowStreamParameters ¶meters) { - // TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine auto res = duckdb::make_uniq(); res->arrow_array_stream = *reinterpret_cast(factory_ptr); @@ -616,10 +663,44 @@ void stream_schema(ArrowArrayStream *stream, ArrowSchema &schema) { stream->get_schema(stream, &schema); } +// Helper function to build CREATE TABLE SQL statement +static std::string BuildCreateTableSQL(const char *schema, const char *table_name, + const duckdb::vector &types, + const duckdb::vector &names, bool if_not_exists = false) { + std::ostringstream create_table; + create_table << "CREATE TABLE "; + if (if_not_exists) { + create_table << "IF NOT EXISTS "; + } + if (schema) { + create_table << duckdb::KeywordHelper::WriteOptionallyQuoted(schema) << "."; + } + create_table << 
duckdb::KeywordHelper::WriteOptionallyQuoted(table_name) << " ("; + for (idx_t i = 0; i < types.size(); i++) { + create_table << duckdb::KeywordHelper::WriteOptionallyQuoted(names[i]); + create_table << " " << types[i].ToString(); + if (i + 1 < types.size()) { + create_table << ", "; + } + } + create_table << ");"; + return create_table.str(); +} + +// Helper function to build DROP TABLE IF EXISTS SQL statement +static std::string BuildDropTableSQL(const char *schema, const char *table_name) { + std::ostringstream drop_table; + drop_table << "DROP TABLE IF EXISTS "; + if (schema) { + drop_table << duckdb::KeywordHelper::WriteOptionallyQuoted(schema) << "."; + } + drop_table << duckdb::KeywordHelper::WriteOptionallyQuoted(table_name); + return drop_table.str(); +} + AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, const char *schema, struct ArrowArrayStream *input, struct AdbcError *error, IngestionMode ingestion_mode, - bool temporary) { - + bool temporary, int64_t *rows_affected) { if (!connection) { SetError(error, "Missing connection object"); return ADBC_STATUS_INVALID_ARGUMENT; @@ -654,29 +735,60 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, cons auto types = d_converted_schema.GetTypes(); auto names = d_converted_schema.GetNames(); - if (ingestion_mode == IngestionMode::CREATE) { - // We must construct the create table SQL query - std::ostringstream create_table; - create_table << "CREATE TABLE "; - if (schema) { - create_table << schema << "."; - } - create_table << table_name << " ("; - for (idx_t i = 0; i < types.size(); i++) { - create_table << names[i] << " "; - create_table << types[i].ToString(); - if (i + 1 < types.size()) { - create_table << ", "; + // Handle different ingestion modes + switch (ingestion_mode) { + case IngestionMode::CREATE: { + // CREATE mode: Create table, error if already exists + auto sql = BuildCreateTableSQL(schema, table_name, types, names); + duckdb_result result; + if (duckdb_query(connection, sql.c_str(), &result) == DuckDBError) { + const char *error_msg = duckdb_result_error(&result); + // Check if error is about table already existing before destroying result + bool already_exists = error_msg && std::string(error_msg).find("already exists") != std::string::npos; + duckdb_destroy_result(&result); + if (already_exists) { + return ADBC_STATUS_ALREADY_EXISTS; } + return ADBC_STATUS_INTERNAL; } - create_table << ");"; + duckdb_destroy_result(&result); + break; + } + case IngestionMode::APPEND: + // APPEND mode: No pre-check needed + // The appender will naturally fail if the table doesn't exist + break; + case IngestionMode::REPLACE: { + // REPLACE mode: Drop table if exists, then create + auto drop_sql = BuildDropTableSQL(schema, table_name); + auto create_sql = BuildCreateTableSQL(schema, table_name, types, names); duckdb_result result; - if (duckdb_query(connection, create_table.str().c_str(), &result) == DuckDBError) { + if (duckdb_query(connection, drop_sql.c_str(), &result) == DuckDBError) { + SetError(error, duckdb_result_error(&result)); + duckdb_destroy_result(&result); + return ADBC_STATUS_INTERNAL; + } + duckdb_destroy_result(&result); + if (duckdb_query(connection, create_sql.c_str(), &result) == DuckDBError) { SetError(error, duckdb_result_error(&result)); duckdb_destroy_result(&result); return ADBC_STATUS_INTERNAL; } duckdb_destroy_result(&result); + break; + } + case IngestionMode::CREATE_APPEND: { + // CREATE_APPEND mode: Create if not exists, append if exists + auto sql = 
BuildCreateTableSQL(schema, table_name, types, names, true); + duckdb_result result; + if (duckdb_query(connection, sql.c_str(), &result) == DuckDBError) { + SetError(error, duckdb_result_error(&result)); + duckdb_destroy_result(&result); + return ADBC_STATUS_INTERNAL; + } + duckdb_destroy_result(&result); + break; + } } AppenderWrapper appender(connection, schema, table_name); if (!appender.Valid()) { @@ -684,6 +796,9 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, cons } duckdb::ArrowArrayWrapper arrow_array_wrapper; + // Initialize rows_affected counter if requested + int64_t affected = 0; + input->get_next(input, &arrow_array_wrapper.arrow_array); while (arrow_array_wrapper.arrow_array.release) { DataChunkWrapper out_chunk; @@ -693,12 +808,23 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, cons SetError(error, duckdb_error_data_message(res)); duckdb_destroy_error_data(&res); } + // Count rows for rows_affected, if a chunk was produced + if (out_chunk.chunk) { + auto *chunk = reinterpret_cast(out_chunk.chunk); + affected += static_cast(chunk->size()); + } if (duckdb_append_data_chunk(appender.Get(), out_chunk.chunk) != DuckDBSuccess) { + auto error_data = duckdb_appender_error_data(appender.Get()); + SetError(error, duckdb_error_data_message(error_data)); + duckdb_destroy_error_data(&error_data); return ADBC_STATUS_INTERNAL; } arrow_array_wrapper = duckdb::ArrowArrayWrapper(); input->get_next(input, &arrow_array_wrapper.arrow_array); } + if (rows_affected) { + *rows_affected = affected; + } return ADBC_STATUS_OK; } @@ -789,11 +915,9 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru return ADBC_STATUS_INVALID_ARGUMENT; } auto count = prepared_wrapper->statement->data->properties.parameter_count; - if (count == 0) { - count = 1; - } std::vector types(count); - std::vector owned_names(count); + std::vector owned_names; + owned_names.reserve(count); duckdb::vector names(count); for (idx_t i = 0; i < count; i++) { @@ -809,7 +933,7 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru duckdb_arrow_options arrow_options; duckdb_connection_get_arrow_options(wrapper->connection, &arrow_options); - auto res = duckdb_to_arrow_schema(arrow_options, &types[0], names.data(), count, schema); + auto res = duckdb_to_arrow_schema(arrow_options, types.data(), names.data(), count, schema); for (auto &type : types) { duckdb_destroy_logical_type(&type); @@ -824,7 +948,8 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru return ADBC_STATUS_OK; } -static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, AdbcError *error) { +static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, int64_t *rows_affected, + AdbcError *error) { // See ADBC_INGEST_OPTION_TARGET_TABLE D_ASSERT(statement->ingestion_stream.release); D_ASSERT(statement->ingestion_table_name); @@ -834,7 +959,7 @@ static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *s // Ingest into a table from the bound stream return Ingest(statement->connection, statement->ingestion_table_name, statement->db_schema, &stream, error, - statement->ingestion_mode, statement->temporary_table); + statement->ingestion_mode, statement->temporary_table, rows_affected); } AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, @@ -858,10 +983,18 @@ AdbcStatusCode 
StatementExecuteQuery(struct AdbcStatement *statement, struct Arr const auto to_table = wrapper->ingestion_table_name != nullptr; if (has_stream && to_table) { - return IngestToTableFromBoundStream(wrapper, error); + return IngestToTableFromBoundStream(wrapper, rows_affected, error); } auto stream_wrapper = static_cast(malloc(sizeof(DuckDBAdbcStreamWrapper))); - if (has_stream) { + if (!stream_wrapper) { + SetError(error, "Allocation error"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + std::memset(&stream_wrapper->result, 0, sizeof(stream_wrapper->result)); + // Only process the stream if there are parameters to bind + auto prepared_statement_params = reinterpret_cast(wrapper->statement) + ->statement->data->properties.parameter_count; + if (has_stream && prepared_statement_params > 0) { // A stream was bound to the statement, use that to bind parameters ArrowArrayStream stream = wrapper->ingestion_stream; ConvertedSchemaWrapper out_types; @@ -878,8 +1011,6 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr free(stream_wrapper); return ADBC_STATUS_INTERNAL; } - auto prepared_statement_params = - reinterpret_cast(wrapper->statement)->statement->named_param_map.size(); duckdb::ArrowArrayWrapper arrow_array_wrapper; @@ -928,9 +1059,12 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr return ADBC_STATUS_INVALID_ARGUMENT; } } + // Destroy any previous result before overwriting to avoid leaks + duckdb_destroy_result(&stream_wrapper->result); auto res = duckdb_execute_prepared(wrapper->statement, &stream_wrapper->result); if (res != DuckDBSuccess) { SetError(error, duckdb_result_error(&stream_wrapper->result)); + duckdb_destroy_result(&stream_wrapper->result); free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -942,10 +1076,27 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr auto res = duckdb_execute_prepared(wrapper->statement, &stream_wrapper->result); if (res != DuckDBSuccess) { SetError(error, duckdb_result_error(&stream_wrapper->result)); + duckdb_destroy_result(&stream_wrapper->result); + free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } } + // Set rows_affected for queries (if not already set by ingestion path) + if (rows_affected && !(has_stream && to_table)) { + // For DML queries (INSERT/UPDATE/DELETE), duckdb_rows_changed() returns the count + // For SELECT queries, duckdb_rows_changed() returns 0 + auto rows_changed = duckdb_rows_changed(&stream_wrapper->result); + if (rows_changed > 0) { + // This was a DML query + *rows_affected = static_cast(rows_changed); + } else { + // This is a SELECT or other query that returns a result set + // Return -1 to indicate unknown, as results are streamed + *rows_affected = -1; + } + } + if (out) { // We pass ownership of the statement private data to our stream out->private_data = stream_wrapper; @@ -953,6 +1104,10 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr out->get_next = get_next; out->release = release; out->get_last_error = get_last_error; + } else { + // Caller didn't request a stream; clean up resources + duckdb_destroy_result(&stream_wrapper->result); + free(stream_wrapper); } return ADBC_STATUS_OK; @@ -1148,6 +1303,12 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *k } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { wrapper->ingestion_mode = IngestionMode::APPEND; return ADBC_STATUS_OK; + } else if (strcmp(value, 
ADBC_INGEST_OPTION_MODE_REPLACE) == 0) { + wrapper->ingestion_mode = IngestionMode::REPLACE; + return ADBC_STATUS_OK; + } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE_APPEND) == 0) { + wrapper->ingestion_mode = IngestionMode::CREATE_APPEND; + return ADBC_STATUS_OK; } else { SetError(error, "Invalid ingestion mode"); return ADBC_STATUS_INVALID_ARGUMENT; @@ -1484,3 +1645,34 @@ AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct } } // namespace duckdb_adbc + +static void ReleaseError(struct AdbcError *error) { + if (error) { + if (error->message) + delete[] error->message; + error->message = nullptr; + error->release = nullptr; + } +} + +void SetError(struct AdbcError *error, const std::string &message) { + if (!error) + return; + if (error->message) { + // Append + std::string buffer = error->message; + buffer.reserve(buffer.size() + message.size() + 1); + buffer += '\n'; + buffer += message; + error->release(error); + + error->message = new char[buffer.size() + 1]; + buffer.copy(error->message, buffer.size()); + error->message[buffer.size()] = '\0'; + } else { + error->message = new char[message.size() + 1]; + message.copy(error->message, message.size()); + error->message[message.size()] = '\0'; + } + error->release = ReleaseError; +} diff --git a/src/duckdb/src/common/adbc/driver_manager.cpp b/src/duckdb/src/common/adbc/driver_manager.cpp deleted file mode 100644 index 45fb8c24d..000000000 --- a/src/duckdb/src/common/adbc/driver_manager.cpp +++ /dev/null @@ -1,1626 +0,0 @@ -//////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////// -// THIS FILE IS GENERATED BY apache/arrow, DO NOT EDIT MANUALLY // -//////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////// - -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/adbc/driver_manager.h" -#include "duckdb/common/adbc/adbc.h" -#include "duckdb/common/adbc/adbc.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#if defined(_WIN32) -#include // Must come first - -#include -#include -#else -#include -#endif // defined(_WIN32) - -// Platform-specific helpers - -#if defined(_WIN32) -/// Append a description of the Windows error to the buffer. 
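A minimal sketch (not part of the patch) of driving the new REPLACE ingestion mode through the statement options handled in StatementSetOption above. The header path, the helper name IngestReplacingTable, the table name "my_table" and the availability of duckdb_adbc::StatementBindStream are assumptions; the option constants and the rows_affected behaviour come from this patch.

#include "duckdb/common/adbc/adbc.hpp"  // assumed header path for the duckdb_adbc:: declarations

static void IngestReplacingTable(AdbcStatement &statement, ArrowArrayStream &data) {
	AdbcError error = {};
	// Route execution into the ingestion path instead of running a SQL query.
	duckdb_adbc::StatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, "my_table", &error);
	// REPLACE drops an existing table and recreates it; CREATE_APPEND would create it
	// only when missing and append otherwise.
	duckdb_adbc::StatementSetOption(&statement, ADBC_INGEST_OPTION_MODE, ADBC_INGEST_OPTION_MODE_REPLACE, &error);
	// Assumed to exist in this driver: binds the Arrow stream that the ingestion path consumes.
	duckdb_adbc::StatementBindStream(&statement, &data, &error);
	int64_t rows_affected = -1;
	// With the changes above, rows_affected reports the number of ingested rows.
	duckdb_adbc::StatementExecuteQuery(&statement, /*out=*/nullptr, &rows_affected, &error);
}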
-void GetWinError(std::string *buffer) { - DWORD rc = GetLastError(); - LPVOID message; - - FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); - - (*buffer) += '('; - (*buffer) += std::to_string(rc); - (*buffer) += ") "; - (*buffer) += reinterpret_cast(message); - LocalFree(message); -} - -#endif // defined(_WIN32) - -// Error handling - -void ReleaseError(struct AdbcError *error) { - if (error) { - if (error->message) - delete[] error->message; - error->message = nullptr; - error->release = nullptr; - } -} - -void SetError(struct AdbcError *error, const std::string &message) { - if (!error) - return; - if (error->message) { - // Append - std::string buffer = error->message; - buffer.reserve(buffer.size() + message.size() + 1); - buffer += '\n'; - buffer += message; - error->release(error); - - error->message = new char[buffer.size() + 1]; - buffer.copy(error->message, buffer.size()); - error->message[buffer.size()] = '\0'; - } else { - error->message = new char[message.size() + 1]; - message.copy(error->message, message.size()); - error->message[message.size()] = '\0'; - } - error->release = ReleaseError; -} - -// Driver state - -/// A driver DLL. -struct ManagedLibrary { - ManagedLibrary() : handle(nullptr) { - } - ManagedLibrary(ManagedLibrary &&other) : handle(other.handle) { - other.handle = nullptr; - } - ManagedLibrary(const ManagedLibrary &) = delete; - ManagedLibrary &operator=(const ManagedLibrary &) = delete; - ManagedLibrary &operator=(ManagedLibrary &&other) noexcept { - this->handle = other.handle; - other.handle = nullptr; - return *this; - } - - ~ManagedLibrary() { - Release(); - } - - void Release() { - // TODO(apache/arrow-adbc#204): causes tests to segfault - // Need to refcount the driver DLL; also, errors may retain a reference to - // release() from the DLL - how to handle this? 
- } - - AdbcStatusCode Load(const char *library, struct AdbcError *error) { - std::string error_message; -#if defined(_WIN32) - HMODULE handle = LoadLibraryExA(library, NULL, 0); - if (!handle) { - error_message += library; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = library; - full_driver_name += ".dll"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; - } else { - this->handle = handle; - } -#else - const std::string kPlatformLibraryPrefix = "lib"; -#if defined(__APPLE__) - const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void *handle = dlopen(library, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = library; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += library; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); - } - } - if (handle) { - this->handle = handle; - } else { - return ADBC_STATUS_INTERNAL; - } -#endif // defined(_WIN32) - return ADBC_STATUS_OK; - } - - AdbcStatusCode Lookup(const char *name, void **func, struct AdbcError *error) { -#if defined(_WIN32) - void *load_handle = reinterpret_cast(GetProcAddress(handle, name)); - if (!load_handle) { - std::string message = "GetProcAddress("; - message += name; - message += ") failed: "; - GetWinError(&message); - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } -#else - void *load_handle = dlsym(handle, name); - if (!load_handle) { - std::string message = "dlsym("; - message += name; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } -#endif // defined(_WIN32) - *func = load_handle; - return ADBC_STATUS_OK; - } - -#if defined(_WIN32) - // The loaded DLL - HMODULE handle; -#else - void *handle; -#endif // defined(_WIN32) -}; - -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); - - ManagedLibrary handle; -}; - -/// Unload the driver DLL. 
-static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { - AdbcStatusCode status = ADBC_STATUS_OK; - - if (!driver->private_manager) - return status; - ManagerDriverState *state = reinterpret_cast(driver->private_manager); - - if (state->driver_release) { - status = state->driver_release(driver, error); - } - state->handle.Release(); - - driver->private_manager = nullptr; - delete state; - return status; -} - -// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream - -struct ErrorArrayStream { - struct ArrowArrayStream stream; - struct AdbcDriver *private_driver; -}; - -void ErrorArrayStreamRelease(struct ArrowArrayStream *stream) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return; - - auto *private_data = reinterpret_cast(stream->private_data); - private_data->stream.release(&private_data->stream); - delete private_data; - std::memset(stream, 0, sizeof(*stream)); -} - -const char *ErrorArrayStreamGetLastError(struct ArrowArrayStream *stream) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return nullptr; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_last_error(&private_data->stream); -} - -int ErrorArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *array) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return EINVAL; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_next(&private_data->stream, array); -} - -int ErrorArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *schema) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return EINVAL; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_schema(&private_data->stream, schema); -} - -// Default stubs - -int ErrorGetDetailCount(const struct AdbcError *error) { - return 0; -} - -struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError *error, int index) { - return {nullptr, nullptr, 0}; -} - -const struct AdbcError *ErrorFromArrayStream(struct ArrowArrayStream *stream, AdbcStatusCode *status) { - return nullptr; -} - -void ErrorArrayStreamInit(struct ArrowArrayStream *out, struct AdbcDriver *private_driver) { - if (!out || !out->release || - // Don't bother wrapping if driver didn't claim support - private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { - return; - } - struct ErrorArrayStream *private_data = new ErrorArrayStream; - private_data->stream = *out; - private_data->private_driver = private_driver; - out->get_last_error = ErrorArrayStreamGetLastError; - out->get_next = ErrorArrayStreamGetNext; - out->get_schema = ErrorArrayStreamGetSchema; - out->release = ErrorArrayStreamRelease; - out->private_data = private_data; -} - -AdbcStatusCode DatabaseGetOption(struct AdbcDatabase *database, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase *database, const char *key, uint8_t *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase *database, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase *database, const char *key, double *value, - struct AdbcError *error) { - 
return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseSetOptionBytes(struct AdbcDatabase *database, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase *database, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase *database, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCancel(struct AdbcConnection *connection, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetOption(struct AdbcConnection *connection, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection *connection, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection *connection, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection *connection, const char *key, double *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection *, const char *, const char *, const char *, char, - struct ArrowArrayStream *, struct AdbcError *) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection *, struct ArrowArrayStream *, struct AdbcError *) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection *, const char *, const uint8_t *, size_t, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection *connection, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection *connection, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementCancel(struct AdbcStatement *statement, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementGetOption(struct AdbcStatement *statement, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement *statement, const char *key, uint8_t *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionInt(struct AdbcStatement *statement, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement *statement, const char *key, double *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement *, const char *, const uint8_t *, size_t, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode 
StatementSetOptionInt(struct AdbcStatement *statement, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement *statement, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -/// Temporary state while the database is being configured. -struct TempDatabase { - std::unordered_map options; - std::unordered_map bytes_options; - std::unordered_map int_options; - std::unordered_map double_options; - std::string driver; - std::string entrypoint; - AdbcDriverInitFunc init_func = nullptr; -}; - -/// Temporary state while the database is being configured. -struct TempConnection { - std::unordered_map options; - std::unordered_map bytes_options; - std::unordered_map int_options; - std::unordered_map double_options; -}; - -static const char kDefaultEntrypoint[] = "AdbcDriverInit"; - -// Other helpers (intentionally not in an anonymous namespace so they can be tested) - -ADBC_EXPORT -std::string AdbcDriverManagerDefaultEntrypoint(const std::string &driver) { - /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit - /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit - /// - proprietary_driver.dll -> AdbcProprietaryDriverInit - - // Potential path -> filename - // Treat both \ and / as directory separators on all platforms for simplicity - std::string filename; - { - size_t pos = driver.find_last_of("/\\"); - if (pos != std::string::npos) { - filename = driver.substr(pos + 1); - } else { - filename = driver; - } - } - - // Remove all extensions - { - size_t pos = filename.find('.'); - if (pos != std::string::npos) { - filename = filename.substr(0, pos); - } - } - - // Remove lib prefix - // https://stackoverflow.com/q/1878001/262727 - if (filename.rfind("lib", 0) == 0) { - filename = filename.substr(3); - } - - // Split on underscores, hyphens - // Capitalize and join - std::string entrypoint; - entrypoint.reserve(filename.size()); - size_t pos = 0; - while (pos < filename.size()) { - size_t prev = pos; - pos = filename.find_first_of("-_", pos); - // if pos == npos this is the entire filename - std::string token = filename.substr(prev, pos - prev); - // capitalize first letter - token[0] = duckdb::NumericCast(std::toupper(static_cast(token[0]))); - - entrypoint += token; - - if (pos != std::string::npos) { - pos++; - } - } - - if (entrypoint.rfind("Adbc", 0) != 0) { - entrypoint = "Adbc" + entrypoint; - } - entrypoint += "Init"; - - return entrypoint; -} - -// Direct implementations of API methods - -int AdbcErrorGetDetailCount(const struct AdbcError *error) { - if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && error->private_driver) { - return error->private_driver->ErrorGetDetailCount(error); - } - return 0; -} - -struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError *error, int index) { - if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && error->private_driver) { - return error->private_driver->ErrorGetDetail(error, index); - } - return {nullptr, nullptr, 0}; -} - -const struct AdbcError *AdbcErrorFromArrayStream(struct ArrowArrayStream *stream, AdbcStatusCode *status) { - if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { - return nullptr; - } - auto *private_data = reinterpret_cast(stream->private_data); - auto *error = private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); - if (error) { - 
const_cast(error)->private_driver = private_data->private_driver; - } - return error; -} - -#define INIT_ERROR(ERROR, SOURCE) \ - if ((ERROR) != nullptr && (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ - (ERROR)->private_driver = (SOURCE)->private_driver; \ - } - -#define WRAP_STREAM(EXPR, OUT, SOURCE) \ - if (!(OUT)) { \ - /* Happens for ExecuteQuery where out is optional */ \ - return EXPR; \ - } \ - AdbcStatusCode status_code = EXPR; \ - ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ - return status_code; - -AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCommit(struct AdbcConnection *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, const uint32_t *info_codes, - size_t info_codes_length, struct ArrowArrayStream *out, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *, int, const char *, const char *, const char *, - const char **, const char *, struct ArrowArrayStream *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *, const char *, const char *, const char *, - struct ArrowSchema *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *, struct ArrowArrayStream *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionRollback(struct AdbcConnection *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOption(struct AdbcConnection *, const char *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementBind(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementPrepare(struct AdbcStatement *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetOption(struct AdbcStatement *, const char *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { - // Allocate a temporary structure to store options pre-Init - database->private_data = 
new TempDatabase(); - database->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase *database, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOption(database, key, value, length, error); - } - const auto *args = reinterpret_cast(database->private_data); - const std::string *result = nullptr; - if (std::strcmp(key, "driver") == 0) { - result = &args->driver; - } else if (std::strcmp(key, "entrypoint") == 0) { - result = &args->entrypoint; - } else { - const auto it = args->options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - result = &it->second; - } - - if (*length <= result->size() + 1) { - // Enough space - std::memcpy(value, result->c_str(), result->size() + 1); - } - *length = result->size() + 1; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase *database, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, error); - } - const auto *args = reinterpret_cast(database->private_data); - const auto it = args->bytes_options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - const std::string &result = it->second; - - if (*length <= result.size()) { - // Enough space - std::memcpy(value, result.c_str(), result.size()); - } - *length = result.size(); - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase *database, const char *key, int64_t *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionInt(database, key, value, error); - } - const auto *args = reinterpret_cast(database->private_data); - const auto it = args->int_options.find(key); - if (it == args->int_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase *database, const char *key, double *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); - } - const auto *args = reinterpret_cast(database->private_data); - const auto it = args->double_options.find(key); - if (it == args->double_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOption(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - if (std::strcmp(key, "driver") == 0) { - args->driver = value; - } else if (std::strcmp(key, "entrypoint") == 0) { - args->entrypoint = value; - } else { - args->options[key] = value; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase *database, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - 
return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->bytes_options[key] = std::string(reinterpret_cast(value), length); - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase *database, const char *key, int64_t value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOptionInt(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->int_options[key] = value; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase *database, const char *key, double value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->double_options[key] = value; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, - struct AdbcError *error) { - if (database->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->init_func = init_func; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_data) { - SetError(error, "Must call AdbcDatabaseNew first"); - return ADBC_STATUS_INVALID_STATE; - } - TempDatabase *args = reinterpret_cast(database->private_data); - if (args->init_func) { - // Do nothing - } else if (args->driver.empty()) { - SetError(error, "Must provide 'driver' parameter"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - database->private_driver = new AdbcDriver; - std::memset(database->private_driver, 0, sizeof(AdbcDriver)); - AdbcStatusCode status; - // So we don't confuse a driver into thinking it's initialized already - database->private_data = nullptr; - if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else if (!args->entrypoint.empty()) { - status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_1_1_0, - database->private_driver, error); - } else { - status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, database->private_driver, error); - } - if (status != ADBC_STATUS_OK) { - // Restore private_data so it will be released by AdbcDatabaseRelease - database->private_data = args; - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - status = database->private_driver->DatabaseNew(database, error); - if (status != ADBC_STATUS_OK) { - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - auto options = std::move(args->options); - auto bytes_options = std::move(args->bytes_options); - auto int_options = std::move(args->int_options); - auto double_options = std::move(args->double_options); - delete args; - - INIT_ERROR(error, database); - for (const auto &option : options) { - 
status = - database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : bytes_options) { - status = database->private_driver->DatabaseSetOptionBytes( - database, option.first.c_str(), reinterpret_cast(option.second.data()), - option.second.size(), error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : int_options) { - status = database->private_driver->DatabaseSetOptionInt(database, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : double_options) { - status = - database->private_driver->DatabaseSetOptionDouble(database, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - break; - } - - if (status != ADBC_STATUS_OK) { - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; - } - return database->private_driver->DatabaseInit(database, error); -} - -AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_driver) { - if (database->private_data) { - TempDatabase *args = reinterpret_cast(database->private_data); - delete args; - database->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, database); - auto status = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_data = nullptr; - database->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionCancel(connection, error); -} - -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionCommit(connection, error); -} - -AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, const uint32_t *info_codes, - size_t info_codes_length, struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, - const char *db_schema, const char *table_name, const char **table_types, - const char *column_name, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - 
WRAP_STREAM(connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_types, column_name, stream, error), - stream, connection); -} - -AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection *connection, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast(connection->private_data); - const auto it = args->options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - if (*length >= it->second.size() + 1) { - std::memcpy(value, it->second.c_str(), it->second.size() + 1); - } - *length = it->second.size() + 1; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOption(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection *connection, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast(connection->private_data); - const auto it = args->bytes_options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - if (*length >= it->second.size() + 1) { - std::memcpy(value, it->second.data(), it->second.size() + 1); - } - *length = it->second.size() + 1; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection *connection, const char *key, int64_t *value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast(connection->private_data); - const auto it = args->int_options.find(key); - if (it == args->int_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionInt(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection *connection, const char *key, double *value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast(connection->private_data); - const auto it = args->double_options.find(key); - if (it == args->double_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection *connection, const char *catalog, - 
const char *db_schema, const char *table_name, char approximate, - struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetStatistics(connection, catalog, db_schema, table_name, - approximate == 1, out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection *connection, struct ArrowArrayStream *out, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetStatisticNames(connection, out, error), out, connection); -} - -AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, - const char *db_schema, const char *table_name, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, - error); -} - -AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetTableTypes(connection, stream, error), stream, connection); -} - -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, - struct AdbcError *error) { - - if (!connection->private_data) { - SetError(error, "Must call AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } else if (!database->private_driver) { - SetError(error, "Database is not initialized"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - TempConnection *args = reinterpret_cast(connection->private_data); - connection->private_data = nullptr; - std::unordered_map options = std::move(args->options); - std::unordered_map bytes_options = std::move(args->bytes_options); - std::unordered_map int_options = std::move(args->int_options); - std::unordered_map double_options = std::move(args->double_options); - delete args; - - auto status = database->private_driver->ConnectionNew(connection, error); - if (status != ADBC_STATUS_OK) - return status; - connection->private_driver = database->private_driver; - - for (const auto &option : options) { - status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), - error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : bytes_options) { - status = database->private_driver->ConnectionSetOptionBytes( - connection, option.first.c_str(), reinterpret_cast(option.second.data()), - option.second.size(), error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : int_options) { - status = - database->private_driver->ConnectionSetOptionInt(connection, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : double_options) { - status = - database->private_driver->ConnectionSetOptionDouble(connection, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - return status; - } - INIT_ERROR(error, connection); - return 
connection->private_driver->ConnectionInit(connection, database, error); -} - -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { - // Allocate a temporary structure to store options pre-Init, because - // we don't get access to the database (and hence the driver - // function table) until then - connection->private_data = new TempConnection; - connection->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, - out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - if (connection->private_data) { - TempConnection *args = reinterpret_cast(connection->private_data); - delete args; - connection->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - auto status = connection->private_driver->ConnectionRelease(connection, error); - connection->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionRollback(connection, error); -} - -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, - struct AdbcError *error) { - if (!connection || !connection->private_data) { - SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - args->options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOption(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection *connection, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - args->bytes_options[key] = std::string(reinterpret_cast(value), length); - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection *connection, const char *key, int64_t value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - 
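
// ---------------------------------------------------------------------------
// [Editor's note: illustrative usage sketch, not part of the patch. The helper
//  name ConnectWithOptions is hypothetical and the <adbc.h> include is assumed.]
// The TempConnection bookkeeping in the surrounding code exists so that options
// can be set before AdbcConnectionInit: they are staged in the per-type maps and
// replayed onto the real driver once Init gives access to the driver function
// table. Roughly:
#include <adbc.h>

static AdbcStatusCode ConnectWithOptions(AdbcDatabase *database, AdbcConnection *connection, AdbcError *error) {
	AdbcStatusCode status = AdbcConnectionNew(connection, error);
	if (status != ADBC_STATUS_OK) {
		return status;
	}
	// staged in TempConnection until Init is called
	status = AdbcConnectionSetOption(connection, "adbc.connection.autocommit", ADBC_OPTION_VALUE_DISABLED, error);
	if (status != ADBC_STATUS_OK) {
		return status;
	}
	// replays the staged options against the driver, then calls ConnectionInit
	return AdbcConnectionInit(connection, database, error);
}
// ---------------------------------------------------------------------------
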
args->int_options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionInt(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection *connection, const char *key, double value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - args->double_options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, error); -} - -AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementBind(statement, values, schema, error); -} - -AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementBindStream(statement, stream, error); -} - -AdbcStatusCode AdbcStatementCancel(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementCancel(statement, error); -} - -// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); -} - -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, - int64_t *rows_affected, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error), out, statement); -} - -AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementExecuteSchema(statement, schema, error); -} - -AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement *statement, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOption(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement *statement, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if 
(!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement *statement, const char *key, int64_t *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionInt(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement *statement, const char *key, double *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionDouble(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetParameterSchema(statement, schema, error); -} - -AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - auto status = connection->private_driver->StatementNew(connection, statement, error); - statement->private_driver = connection->private_driver; - return status; -} - -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementPrepare(statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - auto status = statement->private_driver->StatementRelease(statement, error); - statement->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOption(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement *statement, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement *statement, const char *key, int64_t value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOptionInt(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement *statement, const char *key, double value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, 
statement); - return statement->private_driver->StatementSetOptionDouble(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetSqlQuery(statement, query, error); -} - -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); -} - -const char *AdbcStatusCodeMessage(AdbcStatusCode code) { -#define CASE(CONSTANT) \ - case ADBC_STATUS_##CONSTANT: \ - return #CONSTANT; - - switch (code) { - CASE(OK); - CASE(UNKNOWN); - CASE(NOT_IMPLEMENTED); - CASE(NOT_FOUND); - CASE(ALREADY_EXISTS); - CASE(INVALID_ARGUMENT); - CASE(INVALID_STATE); - CASE(INVALID_DATA); - CASE(INTEGRITY); - CASE(INTERNAL); - CASE(IO); - CASE(CANCELLED); - CASE(TIMEOUT); - CASE(UNAUTHENTICATED); - CASE(UNAUTHORIZED); - default: - return "(invalid code)"; - } -#undef CASE -} - -AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, - struct AdbcError *error) { - AdbcDriverInitFunc init_func; - std::string error_message; - - switch (version) { - case ADBC_VERSION_1_0_0: - case ADBC_VERSION_1_1_0: - break; - default: - SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - if (!raw_driver) { - SetError(error, "Must provide non-NULL raw_driver"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto *driver = reinterpret_cast(raw_driver); - - ManagedLibrary library; - AdbcStatusCode status = library.Load(driver_name, error); - if (status != ADBC_STATUS_OK) { - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return status; - } - - void *load_handle = nullptr; - if (entrypoint) { - status = library.Lookup(entrypoint, &load_handle, error); - } else { - auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); - status = library.Lookup(name.c_str(), &load_handle, error); - if (status != ADBC_STATUS_OK) { - status = library.Lookup(kDefaultEntrypoint, &load_handle, error); - } - } - - if (status != ADBC_STATUS_OK) { - library.Release(); - return status; - } - init_func = reinterpret_cast(load_handle); - - status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); - if (status == ADBC_STATUS_OK) { - ManagerDriverState *state = new ManagerDriverState; - state->driver_release = driver->release; - state->handle = std::move(library); - driver->release = &ReleaseDriver; - driver->private_manager = state; - } else { - library.Release(); - } - return status; -} - -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, - struct AdbcError *error) { - constexpr std::array kSupportedVersions = { - ADBC_VERSION_1_1_0, - ADBC_VERSION_1_0_0, - }; - - if (!raw_driver) { - SetError(error, "Must provide non-NULL raw_driver"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - switch (version) { - case ADBC_VERSION_1_0_0: - case ADBC_VERSION_1_1_0: - break; - default: - SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - -#define FILL_DEFAULT(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ 
- DRIVER->STUB = &STUB; \ - } -#define CHECK_REQUIRED(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ - SetError(error, "Driver does not implement required function Adbc" #STUB); \ - return ADBC_STATUS_INTERNAL; \ - } - - // Starting from the passed version, try each (older) version in - // succession with the underlying driver until we find one that's - // accepted. - AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; - for (const int try_version : kSupportedVersions) { - if (try_version > version) - continue; - result = init_func(try_version, raw_driver, error); - if (result != ADBC_STATUS_NOT_IMPLEMENTED) - break; - } - if (result != ADBC_STATUS_OK) { - return result; - } - - if (version >= ADBC_VERSION_1_0_0) { - auto *driver = reinterpret_cast(raw_driver); - CHECK_REQUIRED(driver, DatabaseNew); - CHECK_REQUIRED(driver, DatabaseInit); - CHECK_REQUIRED(driver, DatabaseRelease); - FILL_DEFAULT(driver, DatabaseSetOption); - - CHECK_REQUIRED(driver, ConnectionNew); - CHECK_REQUIRED(driver, ConnectionInit); - CHECK_REQUIRED(driver, ConnectionRelease); - FILL_DEFAULT(driver, ConnectionCommit); - FILL_DEFAULT(driver, ConnectionGetInfo); - FILL_DEFAULT(driver, ConnectionGetObjects); - FILL_DEFAULT(driver, ConnectionGetTableSchema); - FILL_DEFAULT(driver, ConnectionGetTableTypes); - FILL_DEFAULT(driver, ConnectionReadPartition); - FILL_DEFAULT(driver, ConnectionRollback); - FILL_DEFAULT(driver, ConnectionSetOption); - - FILL_DEFAULT(driver, StatementExecutePartitions); - CHECK_REQUIRED(driver, StatementExecuteQuery); - CHECK_REQUIRED(driver, StatementNew); - CHECK_REQUIRED(driver, StatementRelease); - FILL_DEFAULT(driver, StatementBind); - FILL_DEFAULT(driver, StatementGetParameterSchema); - FILL_DEFAULT(driver, StatementPrepare); - FILL_DEFAULT(driver, StatementSetOption); - FILL_DEFAULT(driver, StatementSetSqlQuery); - FILL_DEFAULT(driver, StatementSetSubstraitPlan); - } - if (version >= ADBC_VERSION_1_1_0) { - auto *driver = reinterpret_cast(raw_driver); - FILL_DEFAULT(driver, ErrorGetDetailCount); - FILL_DEFAULT(driver, ErrorGetDetail); - FILL_DEFAULT(driver, ErrorFromArrayStream); - - FILL_DEFAULT(driver, DatabaseGetOption); - FILL_DEFAULT(driver, DatabaseGetOptionBytes); - FILL_DEFAULT(driver, DatabaseGetOptionDouble); - FILL_DEFAULT(driver, DatabaseGetOptionInt); - FILL_DEFAULT(driver, DatabaseSetOptionBytes); - FILL_DEFAULT(driver, DatabaseSetOptionDouble); - FILL_DEFAULT(driver, DatabaseSetOptionInt); - - FILL_DEFAULT(driver, ConnectionCancel); - FILL_DEFAULT(driver, ConnectionGetOption); - FILL_DEFAULT(driver, ConnectionGetOptionBytes); - FILL_DEFAULT(driver, ConnectionGetOptionDouble); - FILL_DEFAULT(driver, ConnectionGetOptionInt); - FILL_DEFAULT(driver, ConnectionGetStatistics); - FILL_DEFAULT(driver, ConnectionGetStatisticNames); - FILL_DEFAULT(driver, ConnectionSetOptionBytes); - FILL_DEFAULT(driver, ConnectionSetOptionDouble); - FILL_DEFAULT(driver, ConnectionSetOptionInt); - - FILL_DEFAULT(driver, StatementCancel); - FILL_DEFAULT(driver, StatementExecuteSchema); - FILL_DEFAULT(driver, StatementGetOption); - FILL_DEFAULT(driver, StatementGetOptionBytes); - FILL_DEFAULT(driver, StatementGetOptionDouble); - FILL_DEFAULT(driver, StatementGetOptionInt); - FILL_DEFAULT(driver, StatementSetOptionBytes); - FILL_DEFAULT(driver, StatementSetOptionDouble); - FILL_DEFAULT(driver, StatementSetOptionInt); - } - - return ADBC_STATUS_OK; - -#undef FILL_DEFAULT -#undef CHECK_REQUIRED -} diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator.cpp index 977087939..c82302b48 
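
// ---------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of the patch. MyDriverInit is a
//  hypothetical driver entry point; the <adbc.h> include and the
//  ADBC_DRIVER_1_0_0_SIZE macro are assumed from the standard ADBC header.]
// AdbcLoadDriverFromInitFunc above walks kSupportedVersions downwards from the
// requested version, so a driver that only understands ADBC 1.0.0 can still be
// loaded by a 1.1.0-aware manager; any missing 1.1.0 entry points are then
// stubbed in by FILL_DEFAULT. A minimal init function that negotiates this way:
#include <adbc.h>
#include <cstring>

static AdbcStatusCode MyDriverInit(int version, void *raw_driver, AdbcError *error) {
	if (version != ADBC_VERSION_1_0_0) {
		// the manager's loop will retry with an older version
		return ADBC_STATUS_NOT_IMPLEMENTED;
	}
	auto *driver = static_cast<AdbcDriver *>(raw_driver);
	std::memset(driver, 0, ADBC_DRIVER_1_0_0_SIZE);
	// a real driver must also populate the entry points checked by CHECK_REQUIRED
	// (DatabaseNew, DatabaseInit, ConnectionNew, StatementExecuteQuery, ...)
	return ADBC_STATUS_OK;
}

// AdbcLoadDriverFromInitFunc(&MyDriverInit, ADBC_VERSION_1_1_0, &driver, &error)
// first offers 1.1.0, receives NOT_IMPLEMENTED, and then succeeds with 1.0.0.
// ---------------------------------------------------------------------------
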
100644 --- a/src/duckdb/src/common/allocator.cpp +++ b/src/duckdb/src/common/allocator.cpp @@ -9,7 +9,6 @@ #include "duckdb/common/types/timestamp.hpp" #include - #ifdef DUCKDB_DEBUG_ALLOCATION #include "duckdb/common/mutex.hpp" #include "duckdb/common/pair.hpp" @@ -35,6 +34,8 @@ namespace duckdb { +constexpr const idx_t Allocator::MAXIMUM_ALLOC_SIZE; + AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { } @@ -254,7 +255,7 @@ static void MallocTrim(idx_t pad) { return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it } - // We succesfully updated LAST_TRIM_TIMESTAMP_MS, we can trim + // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim malloc_trim(pad); #endif } diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index d5acf3698..b5429763b 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -358,7 +358,6 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co } child.children = &root_holder.nested_children_ptr.back()[0]; for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - InitializeChild(*child.children[type_idx], root_holder); root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); diff --git a/src/duckdb/src/common/arrow/arrow_query_result.cpp b/src/duckdb/src/common/arrow/arrow_query_result.cpp index 396a99944..608a0bd32 100644 --- a/src/duckdb/src/common/arrow/arrow_query_result.cpp +++ b/src/duckdb/src/common/arrow/arrow_query_result.cpp @@ -16,10 +16,7 @@ ArrowQueryResult::ArrowQueryResult(StatementType statement_type, StatementProper ArrowQueryResult::ArrowQueryResult(ErrorData error) : QueryResult(QueryResultType::ARROW_RESULT, std::move(error)) { } -unique_ptr ArrowQueryResult::Fetch() { - throw NotImplementedException("Can't 'Fetch' from ArrowQueryResult"); -} -unique_ptr ArrowQueryResult::FetchRaw() { +unique_ptr ArrowQueryResult::FetchInternal() { throw NotImplementedException("Can't 'FetchRaw' from ArrowQueryResult"); } diff --git a/src/duckdb/src/common/arrow/arrow_type_extension.cpp b/src/duckdb/src/common/arrow/arrow_type_extension.cpp index 93979cd36..d3dff923c 100644 --- a/src/duckdb/src/common/arrow/arrow_type_extension.cpp +++ b/src/duckdb/src/common/arrow/arrow_type_extension.cpp @@ -7,6 +7,8 @@ #include "duckdb/common/arrow/schema_metadata.hpp" #include "duckdb/common/types/vector.hpp" +#include "yyjson.hpp" + namespace duckdb { ArrowTypeExtension::ArrowTypeExtension(string extension_name, string arrow_format, @@ -365,6 +367,72 @@ struct ArrowBool8 { } }; +struct ArrowGeometry { + static unique_ptr GetType(const ArrowSchema &schema, const ArrowSchemaMetadata &schema_metadata) { + // Validate extension metadata. This metadata also contains a CRS, which we drop + // because the GEOMETRY type does not implement a CRS at the type level (yet). 
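
// ---------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of the patch. ValidateGeoArrowMetadata
//  is a hypothetical free function that mirrors the checks performed below.]
// The GeoArrow extension metadata is a small JSON object; the reader only has to
// reject invalid JSON, a non-object root, and non-planar edges, e.g.:
#include "yyjson.hpp"
#include <cstring>
#include <stdexcept>
#include <string>

static void ValidateGeoArrowMetadata(const std::string &metadata) {
	if (metadata.empty()) {
		return; // no metadata attached - nothing to validate
	}
	using namespace duckdb_yyjson;
	yyjson_doc *doc = yyjson_read(metadata.data(), metadata.size(), YYJSON_READ_NOFLAG);
	if (!doc) {
		throw std::runtime_error("Invalid JSON in GeoArrow metadata");
	}
	yyjson_val *root = yyjson_doc_get_root(doc);
	if (!yyjson_is_obj(root)) {
		yyjson_doc_free(doc);
		throw std::runtime_error("Invalid GeoArrow metadata: not a JSON object");
	}
	yyjson_val *edges = yyjson_obj_get(root, "edges");
	bool non_planar = edges && yyjson_is_str(edges) && std::strcmp(yyjson_get_str(edges), "planar") != 0;
	yyjson_doc_free(doc);
	if (non_planar) {
		throw std::runtime_error("Can't import non-planar edges");
	}
}
// e.g. ValidateGeoArrowMetadata(R"({"edges": "spherical"})") throws,
//      while "{}" and an empty string are accepted.
// ---------------------------------------------------------------------------
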
+ const auto extension_metadata = schema_metadata.GetOption(ArrowSchemaMetadata::ARROW_METADATA_KEY); + if (!extension_metadata.empty()) { + unique_ptr doc( + duckdb_yyjson::yyjson_read(extension_metadata.data(), extension_metadata.size(), + duckdb_yyjson::YYJSON_READ_NOFLAG), + duckdb_yyjson::yyjson_doc_free); + if (!doc) { + throw SerializationException("Invalid JSON in GeoArrow metadata"); + } + + duckdb_yyjson::yyjson_val *val = yyjson_doc_get_root(doc.get()); + if (!yyjson_is_obj(val)) { + throw SerializationException("Invalid GeoArrow metadata: not a JSON object"); + } + + duckdb_yyjson::yyjson_val *edges = yyjson_obj_get(val, "edges"); + if (edges && yyjson_is_str(edges) && std::strcmp(yyjson_get_str(edges), "planar") != 0) { + throw NotImplementedException("Can't import non-planar edges"); + } + } + + const auto format = string(schema.format); + if (format == "z") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::NORMAL)); + } + if (format == "Z") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::SUPER_SIZE)); + } + if (format == "vz") { + return make_uniq(LogicalType::GEOMETRY(), + make_uniq(ArrowVariableSizeType::VIEW)); + } + throw InvalidInputException("Arrow extension type \"%s\" not supported for geoarrow.wkb", format.c_str()); + } + + static void PopulateSchema(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &schema, const LogicalType &type, + ClientContext &context, const ArrowTypeExtension &extension) { + ArrowSchemaMetadata schema_metadata; + schema_metadata.AddOption(ArrowSchemaMetadata::ARROW_EXTENSION_NAME, "geoarrow.wkb"); + schema_metadata.AddOption(ArrowSchemaMetadata::ARROW_METADATA_KEY, "{}"); + root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); + schema.metadata = root_holder.metadata_info.back().get(); + + const auto options = context.GetClientProperties(); + if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { + schema.format = "Z"; + } else { + schema.format = "z"; + } + } + + static void ArrowToDuck(ClientContext &, Vector &source, Vector &result, idx_t count) { + Geometry::FromBinary(source, result, count, true); + } + + static void DuckToArrow(ClientContext &context, Vector &source, Vector &result, idx_t count) { + Geometry::ToBinary(source, result, count); + } +}; + void ArrowTypeExtensionSet::Initialize(const DBConfig &config) { // Types that are 1:1 config.RegisterArrowExtension({"arrow.uuid", "w:16", make_shared_ptr(LogicalType::UUID)}); @@ -380,6 +448,11 @@ void ArrowTypeExtensionSet::Initialize(const DBConfig &config) { config.RegisterArrowExtension( {"DuckDB", "time_tz", "w:8", make_shared_ptr(LogicalType::TIME_TZ)}); + config.RegisterArrowExtension( + {"geoarrow.wkb", ArrowGeometry::PopulateSchema, ArrowGeometry::GetType, + make_shared_ptr(LogicalType::GEOMETRY(), LogicalType::BLOB, ArrowGeometry::ArrowToDuck, + ArrowGeometry::DuckToArrow)}); + // Types that are 1:n config.RegisterArrowExtension({"arrow.json", &ArrowJson::PopulateSchema, &ArrowJson::GetType, make_shared_ptr(LogicalType::JSON())}); diff --git a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp index 0636865be..a8b225f75 100644 --- a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp +++ b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp @@ -88,7 +88,7 @@ SinkCombineResultType PhysicalArrowCollector::Combine(ExecutionContext &context, return SinkCombineResultType::FINISHED; } -unique_ptr 
PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) { +unique_ptr PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) const { auto &gstate = state_p.Cast(); return std::move(gstate.result); } diff --git a/src/duckdb/src/common/bignum.cpp b/src/duckdb/src/common/bignum.cpp index fb3613e88..4414b3a5b 100644 --- a/src/duckdb/src/common/bignum.cpp +++ b/src/duckdb/src/common/bignum.cpp @@ -1,30 +1,39 @@ #include "duckdb/common/bignum.hpp" #include "duckdb/common/types/bignum.hpp" -#include +#include "duckdb/common/printer.hpp" +#include "duckdb/common/to_string.hpp" namespace duckdb { void PrintBits(const char value) { + string result; for (int i = 7; i >= 0; --i) { - std::cout << ((value >> i) & 1); + result += to_string((value >> i) & 1); } + Printer::RawPrint(OutputStream::STREAM_STDOUT, result); } void bignum_t::Print() const { auto ptr = data.GetData(); auto length = data.GetSize(); + string result; for (idx_t i = 0; i < length; ++i) { - PrintBits(ptr[i]); - std::cout << " "; + for (int j = 7; j >= 0; --j) { + result += to_string((ptr[i] >> j) & 1); + } + result += " "; } - std::cout << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, result); } void BignumIntermediate::Print() const { + string result; for (idx_t i = 0; i < size; ++i) { - PrintBits(static_cast(data[i])); - std::cout << " "; + for (int j = 7; j >= 0; --j) { + result += to_string((data[i] >> j) & 1); + } + result += " "; } - std::cout << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, result); } BignumIntermediate::BignumIntermediate(const bignum_t &value) { @@ -232,7 +241,6 @@ void BignumAddition(data_ptr_t result, int64_t result_end, bool is_target_absolu } string_t BignumIntermediate::Negate(Vector &result_vector) const { - auto target = StringVector::EmptyString(result_vector, size + Bignum::BIGNUM_HEADER_SIZE); auto ptr = target.GetDataWriteable(); diff --git a/src/duckdb/src/common/box_renderer.cpp b/src/duckdb/src/common/box_renderer.cpp index 999392df1..4988611fd 100644 --- a/src/duckdb/src/common/box_renderer.cpp +++ b/src/duckdb/src/common/box_renderer.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/box_renderer.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" @@ -7,9 +8,6 @@ #include "utf8proc_wrapper.hpp" namespace duckdb { - -const idx_t BoxRenderer::SPLIT_COLUMN = idx_t(-1); - //===--------------------------------------------------------------------===// // Result Renderer //===--------------------------------------------------------------------===// @@ -46,6 +44,9 @@ void BaseResultRenderer::Render(ResultRenderType render_mode, const string &val) case ResultRenderType::NULL_VALUE: RenderNull(val, value_type); break; + case ResultRenderType::STRING_LITERAL: + RenderStringLiteral(val, value_type); + break; case ResultRenderType::FOOTER: RenderFooter(val); break; @@ -87,47 +88,283 @@ const string &StringResultRenderer::str() { } //===--------------------------------------------------------------------===// -// Box Renderer +// Box Renderer Implementation //===--------------------------------------------------------------------===// -BoxRenderer::BoxRenderer(BoxRendererConfig config_p) : config(std::move(config_p)) { +struct HighlightingAnnotation { + HighlightingAnnotation(ResultRenderType render_mode, idx_t start) : render_mode(render_mode), start(start) { + } + + ResultRenderType render_mode; + idx_t start; +}; + +struct BoxRenderValue { + BoxRenderValue(string text_p, ResultRenderType 
render_mode, ValueRenderAlignment alignment, + LogicalType type_p = LogicalTypeId::INVALID, optional_idx render_width = optional_idx()) + : text(std::move(text_p)), render_mode(render_mode), alignment(alignment), type(std::move(type_p)), + render_width(render_width) { + } + + string text; + ResultRenderType render_mode; + vector annotations; + ValueRenderAlignment alignment; + LogicalType type; + optional_idx render_width; + bool decomposed = false; +}; + +enum class RenderRowType { ROW_VALUES, SEPARATOR, DIVIDER, FOOTER }; + +struct BoxRendererFooter { + string row_count_str; + string readable_rows_str; + string shown_str; + string column_count_str; + idx_t render_length = 0; + bool must_show_footer = false; + bool show_footer = true; + bool has_hidden_rows = false; + bool has_hidden_columns = false; +}; + +struct BoxRenderRow { + BoxRenderRow(RenderRowType row_type = RenderRowType::ROW_VALUES) // NOLINT: allow implicit conversion + : row_type(row_type) { + } + + RenderRowType row_type; + vector values; +}; + +struct BoxRendererImplementation { + BoxRendererImplementation(BoxRendererConfig &config, ClientContext &context, const vector &names, + const ColumnDataCollection &result, BaseResultRenderer &ss); + +public: + void Render(); + +private: + BoxRendererConfig &config; + ClientContext &context; + vector column_names; + vector result_types; + const ColumnDataCollection &result; + BaseResultRenderer &ss; + vector column_widths; + vector column_boundary_positions; + idx_t total_render_length = 0; + vector render_rows; + BoxRendererFooter footer; + +private: + void RenderValue(const string &value, idx_t column_width, ResultRenderType render_mode, + const vector &annotations, + ValueRenderAlignment alignment = ValueRenderAlignment::MIDDLE, + optional_idx render_width = optional_idx()); + string RenderType(const LogicalType &type); + ValueRenderAlignment TypeAlignment(const LogicalType &type); + string GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r, const LogicalType &type, + ResultRenderType &render_mode); + list FetchRenderCollections(const ColumnDataCollection &result, idx_t top_rows, + idx_t bottom_rows); + list PivotCollections(list input, idx_t row_count); + void ComputeRenderWidths(list &collections, idx_t min_width, idx_t max_width); + void RenderValues(); + void UpdateColumnCountFooter(idx_t column_count, const unordered_set &pruned_columns); + string TruncateValue(const string &value, idx_t column_width, idx_t &pos, idx_t ¤t_render_width); + + void ComputeRowFooter(idx_t row_count, idx_t rendered_rows); + void RenderFooter(idx_t row_count, idx_t column_count); + + string FormatNumber(const string &input); + string ConvertRenderValue(const string &input, const LogicalType &type); + string ConvertRenderValue(const string &input); + void RenderLayoutLine(const char *layout, const char *boundary, const char *left_corner, const char *right_corner); + //! Try to format a large number in a readable way (e.g. 
1234567 -> 1.23 million) + string TryFormatLargeNumber(const string &numeric); + + bool CanPrettyPrint(const LogicalType &type); + bool CanHighlight(const LogicalType &type); + void PrettyPrintValue(BoxRenderValue &render_value, idx_t max_rows, idx_t max_width); + void HighlightValue(BoxRenderValue &render_value); +}; + +BoxRendererImplementation::BoxRendererImplementation(BoxRendererConfig &config, ClientContext &context, + const vector &names, const ColumnDataCollection &result, + BaseResultRenderer &ss) + : config(config), context(context), column_names(names), result(result), ss(ss) { + result_types = result.Types(); } -string BoxRenderer::ToString(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - StringResultRenderer ss; - Render(context, names, result, ss); - return ss.str(); +void BoxRendererImplementation::ComputeRowFooter(idx_t row_count, idx_t rendered_rows) { + footer.column_count_str = to_string(result.ColumnCount()) + " column"; + if (result.ColumnCount() > 1) { + footer.column_count_str += "s"; + } + footer.row_count_str = FormatNumber(to_string(row_count)) + " rows"; + bool has_limited_rows = config.limit > 0 && row_count == config.limit; + if (has_limited_rows) { + footer.row_count_str = "? rows"; + } + if (config.large_number_rendering == LargeNumberRendering::FOOTER && !has_limited_rows) { + footer.readable_rows_str = TryFormatLargeNumber(to_string(row_count)); + if (!footer.readable_rows_str.empty()) { + footer.readable_rows_str += " rows"; + } + } + footer.has_hidden_rows = rendered_rows < row_count; + if (footer.has_hidden_rows) { + if (has_limited_rows) { + footer.shown_str += ">" + FormatNumber(to_string(config.limit - 1)) + " rows, "; + } + footer.shown_str += FormatNumber(to_string(rendered_rows)) + " shown"; + } + footer.must_show_footer = has_limited_rows || footer.has_hidden_rows || row_count == 0; + footer.render_length = MaxValue(MaxValue(footer.row_count_str.size(), footer.shown_str.size() + 2), + footer.readable_rows_str.size() + 2) + + 4; } -void BoxRenderer::Print(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - Printer::Print(ToString(context, names, result)); +void BoxRendererImplementation::UpdateColumnCountFooter(idx_t column_count, + const unordered_set &pruned_columns) { + if (pruned_columns.empty()) { + // no pruned columns - no need to update the footer + return; + } + if (config.render_mode == RenderMode::COLUMNS) { + // in columns mode - pruned columns really means pruned rows + footer.has_hidden_rows = true; + idx_t shown_row_count = column_count - pruned_columns.size(); + footer.shown_str = to_string(shown_row_count - 2) + " shown"; + } else { + footer.has_hidden_columns = true; + idx_t shown_column_count = column_count - pruned_columns.size(); + footer.column_count_str += " (" + to_string(shown_column_count) + " shown)"; + } +} + +void BoxRendererImplementation::Render() { + if (result.ColumnCount() != column_names.size()) { + throw InternalException("Error in BoxRenderer::Render - unaligned columns and names"); + } + auto max_width = config.max_width; + if (max_width == 0) { + if (Printer::IsTerminal(OutputStream::STREAM_STDOUT)) { + max_width = Printer::TerminalWidth(); + } else { + max_width = 120; + } + } + // we do not support max widths under 80 + max_width = MaxValue(80, max_width); + + // figure out how many/which rows to render + idx_t row_count = result.Count(); + idx_t rows_to_render = MinValue(row_count, config.max_rows); + if (row_count <= config.max_rows + 3) 
{ + // hiding rows adds 3 extra rows + // so hiding rows makes no sense if we are only slightly over the limit + // if we are 1 row over the limit hiding rows will actually increase the number of lines we display! + // in this case render all the rows + rows_to_render = row_count; + } + idx_t top_rows; + idx_t bottom_rows; + if (rows_to_render == row_count) { + top_rows = row_count; + bottom_rows = 0; + } else { + top_rows = rows_to_render / 2 + (rows_to_render % 2 != 0 ? 1 : 0); + bottom_rows = rows_to_render - top_rows; + } + ComputeRowFooter(row_count, top_rows + bottom_rows); + + // fetch the top and bottom render collections from the result + auto collections = FetchRenderCollections(result, top_rows, bottom_rows); + if (config.render_mode == RenderMode::COLUMNS && rows_to_render > 0) { + collections = PivotCollections(std::move(collections), row_count); + } + + // for each column, figure out the width + // start off by figuring out the name of the header by looking at the column name and column type + idx_t min_width = footer.must_show_footer ? footer.render_length : 0; + ComputeRenderWidths(collections, min_width, max_width); + + // render boundaries for the individual columns + for (idx_t c = 0; c < column_widths.size(); c++) { + idx_t render_boundary; + if (c == 0) { + render_boundary = column_widths[c] + 2; + } else { + render_boundary = column_boundary_positions[c - 1] + column_widths[c] + 3; + } + column_boundary_positions.push_back(render_boundary); + } + + // now begin rendering + // render the box + RenderValues(); + + // render the row count and column count + idx_t column_count = result_types.size(); + RenderFooter(row_count, column_count); +} + +string BoxRenderer::TruncateValue(const string &value, idx_t column_width, idx_t &pos, idx_t ¤t_render_width) { + idx_t start_pos = pos; + while (pos < value.size()) { + if (value[pos] == '\n') { + // newline character - stop rendering for this line - but skip the newline + idx_t render_pos = pos; + pos++; + return value.substr(start_pos, render_pos - start_pos); + } + // check if this character fits... + auto char_size = Utf8Proc::RenderWidth(value.c_str(), value.size(), pos); + if (current_render_width + char_size > column_width) { + // it doesn't! stop + break; + } + // it does! 
move to the next character + current_render_width += char_size; + pos = Utf8Proc::NextGraphemeCluster(value.c_str(), value.size(), pos); + } + return value.substr(start_pos, pos - start_pos); +} + +string BoxRendererImplementation::TruncateValue(const string &value, idx_t column_width, idx_t &pos, + idx_t ¤t_render_width) { + return BoxRenderer::TruncateValue(value, column_width, pos, current_render_width); } -void BoxRenderer::RenderValue(BaseResultRenderer &ss, const string &value, idx_t column_width, - ResultRenderType render_mode, ValueRenderAlignment alignment) { - auto render_width = Utf8Proc::RenderWidth(value); +void BoxRendererImplementation::RenderValue(const string &value, idx_t column_width, ResultRenderType render_mode, + const vector &annotations, + ValueRenderAlignment alignment, optional_idx render_width_input) { + idx_t render_width; + if (render_width_input.IsValid()) { + render_width = render_width_input.GetIndex(); + if (render_width != Utf8Proc::RenderWidth(value)) { + throw InternalException("Misaligned render width provided for string \"%s\"", value); + } + } else { + render_width = Utf8Proc::RenderWidth(value); + } - const string *render_value = &value; + const_reference render_value(value); string small_value; + idx_t max_render_pos = value.size(); if (render_width > column_width) { // the string is too large to fit in this column! // the size of this column must have been reduced // figure out how much of this value we can render idx_t pos = 0; idx_t current_render_width = config.DOTDOTDOT_LENGTH; - while (pos < value.size()) { - // check if this character fits... - auto char_size = Utf8Proc::RenderWidth(value.c_str(), value.size(), pos); - if (current_render_width + char_size >= column_width) { - // it doesn't! stop - break; - } - // it does! 
move to the next character - current_render_width += char_size; - pos = Utf8Proc::NextGraphemeCluster(value.c_str(), value.size(), pos); - } - small_value = value.substr(0, pos) + config.DOTDOTDOT; - render_value = &small_value; + small_value = TruncateValue(value, column_width, pos, current_render_width); + max_render_pos = small_value.size(); + small_value += config.DOTDOTDOT; render_width = current_render_width; + render_value = const_reference(small_value); } auto padding_count = (column_width - render_width) + 2; idx_t lpadding; @@ -150,11 +387,29 @@ void BoxRenderer::RenderValue(BaseResultRenderer &ss, const string &value, idx_t } ss << config.VERTICAL; ss << string(lpadding, ' '); - ss.Render(render_mode, *render_value); + if (!annotations.empty()) { + // if we have annotations split up the rendering between annotations + idx_t pos = 0; + ResultRenderType active_render_mode = render_mode; + for (auto &annotation : annotations) { + if (annotation.start >= max_render_pos) { + break; + } + auto render_end = MinValue(max_render_pos, annotation.start); + ss.Render(active_render_mode, render_value.get().substr(pos, render_end - pos)); + active_render_mode = annotation.render_mode; + pos = render_end; + } + if (pos < render_value.get().size()) { + ss.Render(active_render_mode, render_value.get().substr(pos, render_value.get().size() - pos)); + } + } else { + ss.Render(render_mode, render_value.get()); + } ss << string(rpadding, ' '); } -string BoxRenderer::RenderType(const LogicalType &type) { +string BoxRendererImplementation::RenderType(const LogicalType &type) { if (type.HasAlias()) { return StringUtil::Lower(type.ToString()); } @@ -188,7 +443,7 @@ string BoxRenderer::RenderType(const LogicalType &type) { } } -ValueRenderAlignment BoxRenderer::TypeAlignment(const LogicalType &type) { +ValueRenderAlignment BoxRendererImplementation::TypeAlignment(const LogicalType &type) { switch (type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::SMALLINT: @@ -209,7 +464,7 @@ ValueRenderAlignment BoxRenderer::TypeAlignment(const LogicalType &type) { } } -string BoxRenderer::TryFormatLargeNumber(const string &numeric) { +string BoxRenderer::TryFormatLargeNumber(const string &numeric, char decimal_sep) { // we only return a readable rendering if the number is > 1 million if (numeric.size() <= 5) { // number too small for sure @@ -270,16 +525,19 @@ string BoxRenderer::TryFormatLargeNumber(const string &numeric) { result += "-"; } result += decimal_str.substr(0, decimal_str.size() - 2); - result += config.decimal_separator == '\0' ? '.' : config.decimal_separator; + result += decimal_sep == '\0' ? '.' 
: decimal_sep; result += decimal_str.substr(decimal_str.size() - 2, 2); result += " "; result += unit; return result; } -list BoxRenderer::FetchRenderCollections(ClientContext &context, - const ColumnDataCollection &result, idx_t top_rows, - idx_t bottom_rows) { +string BoxRendererImplementation::TryFormatLargeNumber(const string &numeric) { + return BoxRenderer::TryFormatLargeNumber(numeric, config.decimal_separator); +} + +list BoxRendererImplementation::FetchRenderCollections(const ColumnDataCollection &result, + idx_t top_rows, idx_t bottom_rows) { auto column_count = result.ColumnCount(); vector varchar_types; for (idx_t c = 0; c < column_count; c++) { @@ -309,6 +567,9 @@ list BoxRenderer::FetchRenderCollections(ClientContext &co idx_t chunk_idx = 0; idx_t row_idx = 0; while (row_idx < top_rows) { + if (context.IsInterrupted()) { + break; + } fetch_result.Reset(); insert_result.Reset(); // fetch the next chunk @@ -353,6 +614,8 @@ list BoxRenderer::FetchRenderCollections(ClientContext &co } insert_result.SetCardinality(1); top_collection.Append(insert_result); + } else { + config.large_number_rendering = LargeNumberRendering::NONE; } } @@ -364,6 +627,9 @@ list BoxRenderer::FetchRenderCollections(ClientContext &co row_idx = 0; chunk_idx = result.ChunkCount() - 1; while (row_idx < bottom_rows) { + if (context.IsInterrupted()) { + break; + } fetch_result.Reset(); insert_result.Reset(); // fetch the next chunk @@ -390,9 +656,8 @@ list BoxRenderer::FetchRenderCollections(ClientContext &co return collections; } -list BoxRenderer::PivotCollections(ClientContext &context, list input, - vector &column_names, - vector &result_types, idx_t row_count) { +list BoxRendererImplementation::PivotCollections(list input, + idx_t row_count) { auto &top = input.front(); auto &bottom = input.back(); @@ -428,6 +693,9 @@ list BoxRenderer::PivotCollections(ClientContext &context, row_chunk.SetValue(current_index++, row_index, RenderType(result_types[c])); for (auto &collection : input) { for (auto &chunk : collection.Chunks(column_ids)) { + if (context.IsInterrupted()) { + break; + } for (idx_t r = 0; r < chunk.size(); r++) { row_chunk.SetValue(current_index++, row_index, chunk.GetValue(0, r)); } @@ -444,7 +712,7 @@ list BoxRenderer::PivotCollections(ClientContext &context, return result; } -string BoxRenderer::ConvertRenderValue(const string &input) { +string BoxRendererImplementation::ConvertRenderValue(const string &input) { string result; result.reserve(input.size()); for (idx_t c = 0; c < input.size(); c++) { @@ -496,7 +764,7 @@ string BoxRenderer::ConvertRenderValue(const string &input) { return result; } -string BoxRenderer::FormatNumber(const string &input) { +string BoxRendererImplementation::FormatNumber(const string &input) { if (config.large_number_rendering == LargeNumberRendering::ALL) { // when large number rendering is set to ALL, we try to format all numbers as large numbers auto number = TryFormatLargeNumber(input); @@ -538,7 +806,7 @@ string BoxRenderer::FormatNumber(const string &input) { return result; } -string BoxRenderer::ConvertRenderValue(const string &input, const LogicalType &type) { +string BoxRendererImplementation::ConvertRenderValue(const string &input, const LogicalType &type) { switch (type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::SMALLINT: @@ -559,8 +827,8 @@ string BoxRenderer::ConvertRenderValue(const string &input, const LogicalType &t } } -string BoxRenderer::GetRenderValue(BaseResultRenderer &ss, ColumnDataRowCollection &rows, idx_t c, idx_t r, - const 
LogicalType &type, ResultRenderType &render_mode) { +string BoxRendererImplementation::GetRenderValue(ColumnDataRowCollection &rows, idx_t c, idx_t r, + const LogicalType &type, ResultRenderType &render_mode) { try { render_mode = ResultRenderType::VALUE; ss.SetValueType(type); @@ -575,365 +843,1166 @@ string BoxRenderer::GetRenderValue(BaseResultRenderer &ss, ColumnDataRowCollecti } } -vector BoxRenderer::ComputeRenderWidths(const vector &names, const vector &result_types, - list &collections, idx_t min_width, - idx_t max_width, vector &column_map, idx_t &total_length) { - auto column_count = result_types.size(); +struct JSONParser { +public: + virtual ~JSONParser() = default; - vector widths; - widths.reserve(column_count); - for (idx_t c = 0; c < column_count; c++) { - auto name_width = Utf8Proc::RenderWidth(ConvertRenderValue(names[c])); - auto type_width = Utf8Proc::RenderWidth(RenderType(result_types[c])); - widths.push_back(MaxValue(name_width, type_width)); - } +protected: + enum class JSONState { REGULAR, IN_QUOTE, ESCAPE }; - // now iterate over the data in the render collection and find out the true max width - for (auto &collection : collections) { - for (auto &chunk : collection.Chunks()) { - for (idx_t c = 0; c < column_count; c++) { - auto string_data = FlatVector::GetData(chunk.data[c]); - for (idx_t r = 0; r < chunk.size(); r++) { - string render_value; - if (FlatVector::IsNull(chunk.data[c], r)) { - render_value = config.null_value; - } else { - render_value = ConvertRenderValue(string_data[r].GetString(), result_types[c]); - } - auto render_width = Utf8Proc::RenderWidth(render_value); - widths[c] = MaxValue(render_width, widths[c]); - } - } + struct Separator { + Separator(char sep) // NOLINT: allow implicit conversion + : sep(sep), inlined(false) { } + + char sep; + bool inlined = false; + }; + +public: + bool Process(const string &value); + +protected: + virtual void HandleNull() { + } + virtual void HandleBracketOpen(char bracket) { + } + virtual void HandleBracketClose(char bracket) { + } + virtual void HandleQuoteStart(char quote) { + } + virtual void HandleQuoteEnd(char quote) { + } + virtual void HandleComma(char comma) { + } + virtual void HandleColon() { + } + virtual void HandleCharacter(char c) { + } + virtual void HandleEscapeStart(char c) { + } + virtual void Finish() { } - // figure out the total length - // we start off with a pipe (|) - total_length = 1; - for (idx_t c = 0; c < widths.size(); c++) { - // each column has a space at the beginning, and a space plus a pipe (|) at the end - // hence + 3 - total_length += widths[c] + 3; +protected: + bool SeparatorIsMatching(Separator &sep, char closing_sep); + idx_t Depth() const { + return separators.size(); } - if (total_length < min_width) { - // if there are hidden rows we should always display that - // stretch up the first column until we have space to show the row count - widths[0] += min_width - total_length; - total_length = min_width; + +protected: + JSONState state = JSONState::REGULAR; + vector separators; + idx_t pos = 0; + bool success = true; +}; + +bool JSONParser::SeparatorIsMatching(Separator &sep, char closing_sep) { + if (sep.sep == '{' && closing_sep == '}') { + return true; } - // now we need to constrain the length - unordered_set pruned_columns; - if (total_length > max_width) { - // before we remove columns, check if we can just reduce the size of columns - for (auto &w : widths) { - if (w > config.max_col_width) { - auto max_diff = w - config.max_col_width; - if (total_length - 
max_diff <= max_width) { - // if we reduce the size of this column we fit within the limits! - // reduce the width exactly enough so that the box fits - w -= total_length - max_width; - total_length = max_width; - break; - } else { - // reducing the width of this column does not make the result fit - // reduce the column width by the maximum amount anyway - w = config.max_col_width; - total_length -= max_diff; + if (sep.sep == '[' && closing_sep == ']') { + return true; + } + return false; +} + +bool IsWhitespaceEscape(const char c) { + // \n and \t are whitespace escapes + return c == 'n' || c == 't'; +} + +bool JSONParser::Process(const string &value) { + separators.clear(); + state = JSONState::REGULAR; + char quote_char = '"'; + bool can_parse_value = false; + pos = 0; + for (; success && pos < value.size(); pos++) { + auto c = value[pos]; + if (state == JSONState::REGULAR) { + if (can_parse_value) { + // check if this is "null" + if (pos + 4 < value.size() && StringUtil::CharacterToLower(c) == 'n' && + StringUtil::CharacterToLower(value[pos + 1]) == 'u' && + StringUtil::CharacterToLower(value[pos + 2]) == 'l' && + StringUtil::CharacterToLower(value[pos + 3]) == 'l') { + HandleNull(); + pos += 3; + continue; } } - } - - if (total_length > max_width) { - // the total length is still too large - // we need to remove columns! - // first, we add 6 characters to the total length - // this is what we need to add the "..." in the middle - total_length += 3 + config.DOTDOTDOT_LENGTH; - // now select columns to prune - // we select columns in zig-zag order starting from the middle - // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc - int64_t offset = 0; - while (total_length > max_width) { - auto c = NumericCast(NumericCast(column_count) / 2 + offset); - total_length -= widths[c] + 3; - pruned_columns.insert(c); - if (offset >= 0) { - offset = -offset - 1; - } else { - offset = -offset; + switch (c) { + case '[': + case '{': { + // add a newline and indentation based on the separator count + separators.push_back(c); + HandleBracketOpen(c); + can_parse_value = c == '['; + break; + } + case '}': + case ']': { + // closing bracket - move to next line and pop back the separator + if (separators.empty() || !SeparatorIsMatching(separators.back(), c)) { + throw InternalException("Failed to parse JSON string %s - invalid JSON", value); } + separators.pop_back(); + HandleBracketClose(c); + break; } - } - } - - bool added_split_column = false; - vector new_widths; - for (idx_t c = 0; c < column_count; c++) { - if (pruned_columns.find(c) == pruned_columns.end()) { - column_map.push_back(c); - new_widths.push_back(widths[c]); - } else { - if (!added_split_column) { - // "..." 
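
// ---------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of the patch.]
// The pruning loop above removes columns in zig-zag order starting from the
// middle, which keeps the leftmost and rightmost columns visible the longest.
// A standalone reproduction of the visiting order for 10 columns:
#include <cstdint>
#include <cstdio>

int main() {
	const int64_t column_count = 10;
	int64_t offset = 0;
	for (int64_t i = 0; i < column_count; i++) {
		int64_t c = column_count / 2 + offset;
		std::printf("%lld ", static_cast<long long>(c));
		offset = offset >= 0 ? -offset - 1 : -offset;
	}
	std::printf("\n"); // prints: 5 4 6 3 7 2 8 1 9 0
	return 0;
}
// ---------------------------------------------------------------------------
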
- column_map.push_back(SPLIT_COLUMN); - new_widths.push_back(config.DOTDOTDOT_LENGTH); - added_split_column = true; + case '"': + case '\'': + HandleQuoteStart(c); + quote_char = c; + state = JSONState::IN_QUOTE; + break; + case ',': + // comma - move to next line + HandleComma(c); + break; + case ':': + HandleColon(); + can_parse_value = true; + break; + case '\\': + // skip literal "\n" and "\t" (these were escaped previously by our rendering algorithm) + if (pos + 1 < value.size() && IsWhitespaceEscape(value[pos + 1])) { + pos++; + break; + } + // if this is not a whitespace escape just handle it + HandleCharacter(c); + break; + case ' ': + case '\t': + case '\n': + // skip whitespace + break; + default: + HandleCharacter(c); + break; + } + } else if (state == JSONState::IN_QUOTE) { + if (c == quote_char) { + // break out of quotes + state = JSONState::REGULAR; + HandleQuoteEnd(c); + } else if (c == '\\') { + // escape + state = JSONState::ESCAPE; + HandleEscapeStart(c); + } else { + HandleCharacter(c); } + } else if (state == JSONState::ESCAPE) { + state = JSONState::IN_QUOTE; + HandleCharacter(c); + } else { + throw InternalException("Invalid json state"); } } - return new_widths; + if (!success) { + return false; + } + Finish(); + return true; } -void BoxRenderer::RenderHeader(const vector &names, const vector &result_types, - const vector &column_map, const vector &widths, - const vector &boundaries, idx_t total_length, bool has_results, - BaseResultRenderer &ss) { - auto column_count = column_map.size(); - // render the top line - ss << config.LTCORNER; - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < column_count && k == boundaries[column_index]) { - ss << config.TMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } +enum class JSONFormattingMode { STANDARD, COMPACT_VERTICAL, COMPACT_HORIZONTAL }; + +enum class JSONComponentType { BRACKET_OPEN, BRACKET_CLOSE, LITERAL, COLON, COMMA, NULL_VALUE }; + +enum class JSONFormattingResult { SUCCESS, TOO_MANY_ROWS, TOO_WIDE }; + +struct JSONComponent { + JSONComponent(JSONComponentType type, string text_p) : type(type), text(std::move(text_p)) { } - ss << config.RTCORNER; - ss << '\n'; - // render the header names - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string name; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - render_mode = ResultRenderType::LAYOUT; - name = config.DOTDOTDOT; + JSONComponentType type; + string text; +}; + +struct JSONFormatter : public JSONParser { +public: + explicit JSONFormatter() { + } + + static void FormatValue(BoxRenderValue &render_value, idx_t max_rows, idx_t max_width) { + // process the components + JSONFormatter formatter; + formatter.Process(render_value.text); + + idx_t indentation_size = 2; + + auto result = + formatter.TryFormat(JSONFormattingMode::STANDARD, render_value, max_rows, max_width, indentation_size); + if (result == JSONFormattingResult::SUCCESS) { + return; + } + // if we exceeded the max row count - try in compact mode + JSONFormattingMode mode; + if (result == JSONFormattingResult::TOO_WIDE) { + // reduce indentation size if the result was too wide + mode = JSONFormattingMode::COMPACT_HORIZONTAL; + indentation_size = 1; } else { - render_mode = ResultRenderType::COLUMN_NAME; - name = ConvertRenderValue(names[column_idx]); + mode = JSONFormattingMode::COMPACT_VERTICAL; + } + result = formatter.TryFormat(mode, render_value, max_rows, max_width, indentation_size); + 
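
// ---------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of the patch. FormatWithFallback
//  and its callbacks are hypothetical stand-ins for TryFormat and its modes.]
// FormatValue above is a fallback chain: render the JSON in the roomiest mode
// first and only fall back to a more compact mode when the result is too wide
// or has too many rows; if every mode fails, the raw value is kept. The same
// shape in isolation:
#include <algorithm>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

static std::string FormatWithFallback(const std::string &input, size_t max_rows, size_t max_width,
                                      const std::vector<std::function<std::string(const std::string &)>> &modes) {
	for (const auto &render : modes) {
		std::string candidate = render(input);
		size_t rows = 1, width = 0, line_width = 0;
		for (char c : candidate) {
			if (c == '\n') {
				rows++;
				line_width = 0;
			} else {
				width = std::max(width, ++line_width);
			}
		}
		if (rows <= max_rows && width <= max_width) {
			return candidate; // first mode that fits wins
		}
	}
	return input; // nothing fits - keep the raw value
}
// ---------------------------------------------------------------------------
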
if (result == JSONFormattingResult::SUCCESS) { + return; } - RenderValue(ss, name, widths[c], render_mode); } - ss << config.VERTICAL; - ss << '\n'; - // render the types - if (config.render_mode == RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string type; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - render_mode = ResultRenderType::LAYOUT; - } else { - render_mode = ResultRenderType::COLUMN_TYPE; - type = RenderType(result_types[column_idx]); - } - RenderValue(ss, type, widths[c], render_mode); - } - ss << config.VERTICAL; - ss << '\n'; +protected: + void HandleNull() override { + components.emplace_back(JSONComponentType::NULL_VALUE, "null"); } - // render the line under the header - ss << config.LMIDDLE; - column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < column_count && k == boundaries[column_index]) { - ss << (has_results ? config.MIDDLE : config.DMIDDLE); - column_index++; - } else { - ss << config.HORIZONTAL; - } + void HandleBracketOpen(char bracket) override { + components.emplace_back(JSONComponentType::BRACKET_OPEN, string(1, bracket)); } - ss << config.RMIDDLE; - ss << '\n'; -} -void BoxRenderer::RenderValues(const list &collections, const vector &column_map, - const vector &widths, const vector &result_types, - BaseResultRenderer &ss) { - auto &top_collection = collections.front(); - auto &bottom_collection = collections.back(); - // render the top rows - auto top_rows = top_collection.Count(); - auto bottom_rows = bottom_collection.Count(); - auto column_count = column_map.size(); + void HandleBracketClose(char bracket) override { + components.emplace_back(JSONComponentType::BRACKET_CLOSE, string(1, bracket)); + } - bool large_number_footer = config.large_number_rendering == LargeNumberRendering::FOOTER; - vector alignments; - if (config.render_mode == RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - if (column_idx == SPLIT_COLUMN) { - alignments.push_back(ValueRenderAlignment::MIDDLE); - } else if (large_number_footer && result_types[column_idx].IsNumeric()) { - alignments.push_back(ValueRenderAlignment::MIDDLE); - } else { - alignments.push_back(TypeAlignment(result_types[column_idx])); - } - } + void HandleQuoteStart(char quote) override { + AddLiteralCharacter(quote); } - auto rows = top_collection.GetRows(); - for (idx_t r = 0; r < top_rows; r++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - render_mode = ResultRenderType::LAYOUT; - } else { - str = GetRenderValue(ss, rows, column_idx, r, result_types[column_idx], render_mode); + void HandleQuoteEnd(char quote) override { + AddLiteralCharacter(quote); + } + + void HandleComma(char comma) override { + components.emplace_back(JSONComponentType::COMMA, ","); + } + + void HandleColon() override { + components.emplace_back(JSONComponentType::COLON, ":"); + } + + void HandleCharacter(char c) override { + AddLiteralCharacter(c); + } + + void HandleEscapeStart(char c) override { + AddLiteralCharacter(c); + } + + void AddLiteralCharacter(char c) { + if (components.empty() || components.back().type != JSONComponentType::LITERAL) { + components.emplace_back(JSONComponentType::LITERAL, ""); + } + components.back().text += c; + } + + struct FormatState { + JSONFormattingMode mode; + string result; + idx_t 
component_idx = 0; + idx_t row_count = 0; + idx_t line_length = 0; + idx_t depth = 0; + idx_t max_rows; + idx_t max_width; + idx_t indentation_size = 2; + JSONFormattingResult format_result = JSONFormattingResult::SUCCESS; + }; + + bool LiteralFits(FormatState &format_state, idx_t render_width) { + auto &line_length = format_state.line_length; + if (line_length + render_width > format_state.max_width) { + return false; + } + return true; + } + + bool LiteralFits(FormatState &format_state, const string &text) { + idx_t render_width = Utf8Proc::RenderWidth(text); + return LiteralFits(format_state, render_width); + } + + void AddLiteral(FormatState &format_state, const string &text, bool skip_adding_if_does_not_fit = false) { + auto &result = format_state.result; + auto &line_length = format_state.line_length; + idx_t render_width = Utf8Proc::RenderWidth(text); + if (!LiteralFits(format_state, render_width)) { + if (skip_adding_if_does_not_fit) { + return; + } + AddNewline(format_state); + if (format_state.format_result != JSONFormattingResult::SUCCESS) { + return; } - ValueRenderAlignment alignment; - if (config.render_mode == RenderMode::ROWS) { - alignment = alignments[c]; - if (large_number_footer && r == 1) { - // render readable numbers with highlighting of a NULL value - render_mode = ResultRenderType::NULL_VALUE; + } + result += text; + line_length += render_width; + if (line_length > format_state.max_width) { + format_state.format_result = JSONFormattingResult::TOO_WIDE; + } + } + void AddSpace(FormatState &format_state) { + AddLiteral(format_state, " ", true); + } + void AddNewline(FormatState &format_state) { + auto &result = format_state.result; + auto &depth = format_state.depth; + auto &row_count = format_state.row_count; + auto &line_length = format_state.line_length; + result += '\n'; + result += string(depth, ' '); + row_count++; + if (row_count > format_state.max_rows) { + format_state.format_result = JSONFormattingResult::TOO_MANY_ROWS; + return; + } + line_length = depth; + if (line_length > format_state.max_width) { + format_state.format_result = JSONFormattingResult::TOO_WIDE; + } + } + + enum class InlineMode { STANDARD, INLINED_SINGLE_LINE, INLINED_MULTI_LINE }; + + void FormatComponent(FormatState &format_state, JSONComponent &component, InlineMode inline_mode) { + auto &depth = format_state.depth; + auto &line_length = format_state.line_length; + auto &max_width = format_state.max_width; + auto &c = format_state.component_idx; + switch (component.type) { + case JSONComponentType::BRACKET_OPEN: { + depth += component.text == "{" ? format_state.indentation_size : 1; + AddLiteral(format_state, component.text); + if (inline_mode == InlineMode::STANDARD) { + // not inlined + // look forward until the corresponding bracket open - can we inline and not exceed the column width? + idx_t peek_depth = 0; + idx_t render_size = line_length; + idx_t peek_idx; + InlineMode inline_child_mode = InlineMode::STANDARD; + for (peek_idx = c + 1; peek_idx < components.size() && render_size <= max_width; peek_idx++) { + auto &peek_component = components[peek_idx]; + if (peek_component.type == JSONComponentType::BRACKET_OPEN) { + peek_depth++; + } else if (peek_component.type == JSONComponentType::BRACKET_CLOSE) { + if (peek_depth == 0) { + // close! 
+ if (render_size + 1 < max_width) { + // fits within a single line - inline on a single line + inline_child_mode = InlineMode::INLINED_SINGLE_LINE; + } + break; + } + peek_depth--; + } + render_size += Utf8Proc::RenderWidth(peek_component.text); + if (peek_component.type == JSONComponentType::COMMA || + peek_component.type == JSONComponentType::COLON) { + render_size++; + } } - } else { - switch (c) { - case 0: - render_mode = ResultRenderType::COLUMN_NAME; - break; - case 1: - render_mode = ResultRenderType::COLUMN_TYPE; - break; - default: - render_mode = ResultRenderType::VALUE; - break; + if (component.text == "[") { + // for arrays - we always inline them UNLESS there are complex objects INSIDE of the bracket + // scan forward until the end of the array to figure out if this is true or not + for (peek_idx = c + 1; peek_idx < components.size(); peek_idx++) { + auto &peek_component = components[peek_idx]; + peek_depth = 0; + if (peek_component.type == JSONComponentType::BRACKET_OPEN) { + if (peek_component.text == "{") { + // nested structure within the array + break; + } + peek_depth++; + } + if (peek_component.type == JSONComponentType::BRACKET_CLOSE) { + if (peek_depth == 0) { + inline_child_mode = InlineMode::INLINED_MULTI_LINE; + break; + } + peek_depth--; + } + } } - if (c < 2) { - alignment = ValueRenderAlignment::LEFT; - } else if (c == SPLIT_COLUMN) { - alignment = ValueRenderAlignment::MIDDLE; - } else { - alignment = ValueRenderAlignment::RIGHT; + if (inline_child_mode != InlineMode::STANDARD) { + // we can inline! do it + for (idx_t inline_idx = c + 1; inline_idx <= peek_idx; inline_idx++) { + auto &inline_component = components[inline_idx]; + if (inline_child_mode == InlineMode::INLINED_MULTI_LINE && inline_idx + 1 <= peek_idx) { + auto &next_component = components[inline_idx + 1]; + if (next_component.type == JSONComponentType::COMMA || + next_component.type == JSONComponentType::BRACKET_CLOSE) { + if (!LiteralFits(format_state, inline_component.text + next_component.text)) { + AddNewline(format_state); + } + } + } + FormatComponent(format_state, inline_component, inline_child_mode); + } + c = peek_idx; + return; + } + if (format_state.mode == JSONFormattingMode::COMPACT_VERTICAL) { + // we can't inline - but is the next token a bracket open? + if (c + 1 < components.size() && components[c + 1].type == JSONComponentType::BRACKET_OPEN) { + // it is! that bracket open will add a newline - we don't need to do it here + return; + } + } + AddNewline(format_state); + } + break; + } + case JSONComponentType::BRACKET_CLOSE: { + idx_t depth_diff = component.text == "}" ? 
format_state.indentation_size : 1; + if (depth < depth_diff) { + // shouldn't happen - but guard against underflows + depth = 0; + } else { + depth -= depth_diff; + } + if (inline_mode == InlineMode::STANDARD) { + AddNewline(format_state); + } + AddLiteral(format_state, component.text); + break; + } + case JSONComponentType::COMMA: + case JSONComponentType::COLON: + AddLiteral(format_state, component.text); + bool always_inline; + if (format_state.mode == JSONFormattingMode::COMPACT_HORIZONTAL) { + // if we are trying to compact horizontally - don't inline colons unless it fits + always_inline = false; + } else { + // in normal processing we always inline colons + always_inline = component.type == JSONComponentType::COLON; + } + if (inline_mode != InlineMode::STANDARD || always_inline) { + AddSpace(format_state); + } else { + if (format_state.mode != JSONFormattingMode::STANDARD) { + // if we are not inlining in compact mode, try to inline until the next comma + idx_t peek_depth = 0; + idx_t render_size = line_length + 1; + idx_t peek_idx; + bool inline_comma = false; + for (peek_idx = c + 1; peek_idx < components.size() && render_size <= max_width; peek_idx++) { + auto &peek_component = components[peek_idx]; + if (peek_component.type == JSONComponentType::BRACKET_OPEN) { + peek_depth++; + } else if (peek_component.type == JSONComponentType::BRACKET_CLOSE) { + if (peek_depth == 0) { + inline_comma = render_size + 1 < max_width; + break; + } + peek_depth--; + } + if (peek_depth == 0 && peek_component.type == JSONComponentType::COMMA) { + // found the next comma - inline! + inline_comma = render_size + 2 <= max_width; + break; + } + render_size += Utf8Proc::RenderWidth(peek_component.text); + if (peek_component.type == JSONComponentType::COMMA || + peek_component.type == JSONComponentType::COLON) { + render_size++; + } + } + if (inline_comma) { + // we can inline until the next comma! 
do it + AddSpace(format_state); + for (idx_t inline_idx = c + 1; inline_idx < peek_idx; inline_idx++) { + auto &inline_component = components[inline_idx]; + FormatComponent(format_state, inline_component, InlineMode::INLINED_SINGLE_LINE); + } + c = peek_idx - 1; + return; + } } + AddNewline(format_state); } - RenderValue(ss, str, widths[c], render_mode, alignment); + break; + case JSONComponentType::NULL_VALUE: + case JSONComponentType::LITERAL: + AddLiteral(format_state, component.text); + break; + default: + throw InternalException("Unsupported JSON component type"); } - ss << config.VERTICAL; - ss << '\n'; } - if (bottom_rows > 0) { - if (config.render_mode == RenderMode::COLUMNS) { - throw InternalException("Columns render mode does not support bottom rows"); + JSONFormattingResult TryFormat(JSONFormattingMode mode, BoxRenderValue &render_value, idx_t max_rows, + idx_t max_width, idx_t indentation_size = 2) { + FormatState format_state; + format_state.mode = mode; + format_state.max_rows = max_rows; + format_state.max_width = max_width; + format_state.indentation_size = indentation_size; + for (format_state.component_idx = 0; format_state.component_idx < components.size() && + format_state.format_result == JSONFormattingResult::SUCCESS; + format_state.component_idx++) { + auto &component = components[format_state.component_idx]; + FormatComponent(format_state, component, InlineMode::STANDARD); + } + + if (format_state.format_result != JSONFormattingResult::SUCCESS) { + return format_state.format_result; + } + render_value.text = format_state.result; + return JSONFormattingResult::SUCCESS; + } + +protected: + vector components; +}; + +struct JSONHighlighter : public JSONParser { +public: + explicit JSONHighlighter(BoxRenderValue &render_value) : render_value(render_value) { + } + +protected: + void HandleNull() override { + render_value.annotations.emplace_back(ResultRenderType::NULL_VALUE, pos); + render_value.annotations.emplace_back(render_value.render_mode, pos + 4); + } + + void HandleQuoteStart(char quote) override { + render_value.annotations.emplace_back(ResultRenderType::STRING_LITERAL, pos); + } + + void HandleQuoteEnd(char quote) override { + render_value.annotations.emplace_back(render_value.render_mode, pos + 1); + } + +protected: + BoxRenderValue &render_value; +}; + +bool BoxRendererImplementation::CanPrettyPrint(const LogicalType &type) { + return type.IsJSONType() || type.IsNested(); +} + +bool BoxRendererImplementation::CanHighlight(const LogicalType &type) { + return type.IsJSONType() || type.IsNested(); +} + +void BoxRendererImplementation::PrettyPrintValue(BoxRenderValue &render_value, idx_t max_rows, idx_t max_width) { + if (!CanPrettyPrint(render_value.type)) { + return; + } + JSONFormatter::FormatValue(render_value, max_rows, max_width); +} + +void BoxRendererImplementation::HighlightValue(BoxRenderValue &render_value) { + if (!CanHighlight(render_value.type)) { + return; + } + JSONHighlighter highlighter(render_value); + highlighter.Process(render_value.text); +} +void BoxRendererImplementation::ComputeRenderWidths(list &collections, idx_t min_width, + idx_t max_width) { + auto column_count = result_types.size(); + idx_t row_count = 0; + for (auto &collection : collections) { + row_count += collection.Count(); + } + + // prepare all rows for rendering + // header / type + BoxRenderRow header_row; + BoxRenderRow type_row; + for (idx_t c = 0; c < column_count; c++) { + header_row.values.emplace_back(ConvertRenderValue(column_names[c]), 
ResultRenderType::COLUMN_NAME, + ValueRenderAlignment::MIDDLE); + type_row.values.emplace_back(RenderType(result_types[c]), ResultRenderType::COLUMN_TYPE, + ValueRenderAlignment::MIDDLE); + } + render_rows.push_back(std::move(header_row)); + if (config.render_mode == RenderMode::ROWS) { + render_rows.push_back(std::move(type_row)); + } + // prepare the values + bool first_render = true; + bool invert = false; + for (auto &collection : collections) { + if (collection.Count() == 0) { + continue; } - // render the bottom rows - // first render the divider - auto brows = bottom_collection.GetRows(); - for (idx_t k = 0; k < 3; k++) { + if (first_render) { + // add a separator if there are any rows + render_rows.emplace_back(RenderRowType::SEPARATOR); + } else { + // render divider between top and bottom collection + render_rows.emplace_back(RenderRowType::DIVIDER); + } + first_render = false; + vector collection_rows; + for (auto &chunk : collection.Chunks()) { + vector chunk_rows; + chunk_rows.resize(chunk.size()); for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - auto alignment = alignments[c]; - if (alignment == ValueRenderAlignment::MIDDLE || column_idx == SPLIT_COLUMN) { - str = config.DOT; - } else { - // align the dots in the center of the column - ResultRenderType render_mode; - auto top_value = - GetRenderValue(ss, rows, column_idx, top_rows - 1, result_types[column_idx], render_mode); - auto bottom_value = - GetRenderValue(ss, brows, column_idx, bottom_rows - 1, result_types[column_idx], render_mode); - auto top_length = MinValue(widths[c], Utf8Proc::RenderWidth(top_value)); - auto bottom_length = MinValue(widths[c], Utf8Proc::RenderWidth(bottom_value)); - auto dot_length = MinValue(top_length, bottom_length); - if (top_length == 0) { - dot_length = bottom_length; - } else if (bottom_length == 0) { - dot_length = top_length; + auto string_data = FlatVector::GetData(chunk.data[c]); + for (idx_t r = 0; r < chunk.size(); r++) { + string render_value; + ResultRenderType render_type; + ValueRenderAlignment alignment; + LogicalType type; + if (FlatVector::IsNull(chunk.data[c], r)) { + render_value = config.null_value; + render_type = ResultRenderType::NULL_VALUE; + } else { + render_value = ConvertRenderValue(string_data[r].GetString(), result_types[c]); + render_type = ResultRenderType::VALUE; } - if (dot_length > 1) { - auto padding = dot_length - 1; - idx_t left_padding, right_padding; - switch (alignment) { - case ValueRenderAlignment::LEFT: - left_padding = padding / 2; - right_padding = padding - left_padding; + if (config.render_mode == RenderMode::ROWS) { + // in rows mode we select alignment for each column based on the type + alignment = TypeAlignment(result_types[c]); + type = result_types[c]; + } else { + // in columns mode we left-align the header rows, and right-align the values + switch (c) { + case 0: + render_type = ResultRenderType::COLUMN_NAME; + alignment = ValueRenderAlignment::LEFT; break; - case ValueRenderAlignment::RIGHT: - right_padding = padding / 2; - left_padding = padding - right_padding; + case 1: + render_type = ResultRenderType::COLUMN_TYPE; + alignment = ValueRenderAlignment::LEFT; break; default: - throw InternalException("Unrecognized value renderer alignment"); + render_type = ResultRenderType::VALUE; + alignment = ValueRenderAlignment::RIGHT; + // for columns rendering mode - the type for this value is determined by the row index + type = result.Types()[render_rows.size() + r - 2]; + break; + } + } + 
chunk_rows[r].values.emplace_back(std::move(render_value), render_type, alignment, std::move(type)); + } + } + for (auto &row : chunk_rows) { + collection_rows.push_back(std::move(row)); + } + } + if (config.large_number_rendering == LargeNumberRendering::FOOTER) { + // when rendering the large number footer we align to the middle + for (auto &row : collection_rows) { + for (auto &value : row.values) { + value.alignment = ValueRenderAlignment::MIDDLE; + } + } + // large number footers should be rendered as NULL values + for (auto &row : collection_rows[1].values) { + row.render_mode = ResultRenderType::NULL_VALUE; + } + } + if (invert) { + // the bottom collection is inverted - so flip the rows + std::reverse(collection_rows.begin(), collection_rows.end()); + } + for (auto &row : collection_rows) { + render_rows.push_back(std::move(row)); + } + invert = true; + } + + // now all rows are prepared - figure out the max width of each of the columns + column_widths.resize(column_count, 0); + for (auto &row : render_rows) { + if (row.row_type != RenderRowType::ROW_VALUES) { + continue; + } + D_ASSERT(row.values.size() == column_count); + for (idx_t c = 0; c < column_count; c++) { + auto render_width = Utf8Proc::RenderWidth(row.values[c].text); + if (render_width > column_widths[c]) { + column_widths[c] = render_width; + } + row.values[c].render_width = render_width; + } + } + + // figure out the total length + // we start off with a pipe (|) + total_render_length = 1; + for (idx_t c = 0; c < column_widths.size(); c++) { + // each column has a space at the beginning, and a space plus a pipe (|) at the end + // hence + 3 + total_render_length += column_widths[c] + 3; + } + if (total_render_length < min_width) { + // if there are hidden rows we should always display that + // stretch up the first column until we have space to show the row count + column_widths[0] += min_width - total_render_length; + total_render_length = min_width; + } + // now we need to constrain the length + unordered_set pruned_columns; + bool shortened_columns = false; + if (total_render_length > max_width) { + auto original_widths = column_widths; + // before we remove columns, check if we can just reduce the size of columns + vector max_shorten_amount; + idx_t total_max_shorten_amount = 0; + for (auto &w : column_widths) { + if (w <= config.max_col_width) { + max_shorten_amount.push_back(0); + continue; + } + auto max_diff = w - config.max_col_width; + max_shorten_amount.push_back(max_diff); + total_max_shorten_amount += max_diff; + } + idx_t shorten_amount_required = total_render_length - max_width; + if (total_max_shorten_amount >= shorten_amount_required) { + // we can get below the max width by shortening + // try to shorten everything to the same size + // i.e. 
if we have one long column and one small column, we would prefer to shorten only the long column + + // map of "shorten amount required -> column index" + map<idx_t, vector<idx_t>> shorten_amount_required_map; + for (idx_t col_idx = 0; col_idx < max_shorten_amount.size(); col_idx++) { + shorten_amount_required_map[max_shorten_amount[col_idx]].push_back(col_idx); + } + vector<idx_t> actual_shorten_amounts; + actual_shorten_amounts.resize(max_shorten_amount.size()); + + while (shorten_amount_required > 0) { + // find the columns with the longest width + auto entry = shorten_amount_required_map.rbegin(); + auto largest_width = entry->first; + auto &column_list = entry->second; + // shorten these columns to the next-shortest width + // move to the second-largest entry - this is the target entry + entry++; + auto second_largest_width = entry == shorten_amount_required_map.rend() ? 0 : entry->first; + auto max_shorten_width = largest_width - second_largest_width; + D_ASSERT(max_shorten_width > 0); + + auto total_potential_shorten_width = max_shorten_width * column_list.size(); + if (total_potential_shorten_width >= shorten_amount_required) { + // we can reach the shorten amount required just by shortening this set of columns + // shorten the columns equally + idx_t shorten_amount_per_column = shorten_amount_required / column_list.size(); + for (auto &column_idx : column_list) { + actual_shorten_amounts[column_idx] += shorten_amount_per_column; + shorten_amount_required -= shorten_amount_per_column; + } + + // because of truncation, we might still need to shorten columns by a single unit + for (idx_t i = column_list.size(); i > 0 && shorten_amount_required > 0; i--) { + actual_shorten_amounts[column_list[i - 1]]++; + shorten_amount_required--; + } + if (shorten_amount_required != 0) { + throw InternalException("Shorten amount required has to be zero now"); + } + + // we are now done + break; + } + if (entry == shorten_amount_required_map.rend()) { + throw InternalException( + "ColumnRenderer - we could not reach the shorten amount required but we ran out of columns?"); + } + // we need to shorten all columns to the width of the next-largest column + for (auto &column_idx : column_list) { + actual_shorten_amounts[column_idx] += max_shorten_width; + } + // add all columns to the second-largest list of columns + auto &second_largest_column_list = entry->second; + second_largest_column_list.insert(second_largest_column_list.end(), column_list.begin(), + column_list.end()); + // delete this entry from the shorten map and continue + shorten_amount_required_map.erase(largest_width); + shorten_amount_required -= total_potential_shorten_width; + } + + // now perform the shortening + for (idx_t c = 0; c < actual_shorten_amounts.size(); c++) { + if (actual_shorten_amounts[c] == 0) { + continue; + } + D_ASSERT(actual_shorten_amounts[c] < column_widths[c]); + column_widths[c] -= actual_shorten_amounts[c]; + total_render_length -= actual_shorten_amounts[c]; + shortened_columns = true; + } + } else { + // we cannot get below the max width by shortening + // set everything that is wider than the col width to the max col width + // afterwards - we need to prune columns + for (auto &w : column_widths) { + if (w <= config.max_col_width) { + continue; + } + total_render_length -= w - config.max_col_width; + w = config.max_col_width; + shortened_columns = true; + } + D_ASSERT(total_render_length > max_width); + } + + if (total_render_length > max_width) { + // the total length is still too large + // we need to remove columns!
+ // first, we add 6 characters to the total length + // this is what we need to add the "..." in the middle + total_render_length += 3 + config.DOTDOTDOT_LENGTH; + // now select columns to prune + // we select columns in zig-zag order starting from the middle + // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc + int64_t offset = 0; + while (total_render_length > max_width) { + auto c = NumericCast(NumericCast(column_count) / 2 + offset); + total_render_length -= column_widths[c] + 3; + pruned_columns.insert(c); + if (offset >= 0) { + offset = -offset - 1; + } else { + offset = -offset; + } + } + + // if we have any space left after truncating columns we can try to increase the size of columns again + idx_t space_left = max_width - total_render_length; + for (idx_t c = 0; c < column_widths.size() && space_left > 0; c++) { + if (pruned_columns.find(c) != pruned_columns.end()) { + // only increase size of visible columns + continue; + } + if (column_widths[c] >= original_widths[c]) { + continue; + } + idx_t increase_amount = MinValue(space_left, original_widths[c] - column_widths[c]); + column_widths[c] += increase_amount; + space_left -= increase_amount; + total_render_length += increase_amount; + } + } + } + + // update the footer with the column counts + UpdateColumnCountFooter(column_count, pruned_columns); + + bool added_split_column = false; + vector new_widths; + vector column_map; + for (idx_t c = 0; c < column_count; c++) { + if (pruned_columns.find(c) == pruned_columns.end()) { + column_map.push_back(c); + new_widths.push_back(column_widths[c]); + } else if (!added_split_column) { + // "..." + column_map.push_back(optional_idx()); + new_widths.push_back(config.DOTDOTDOT_LENGTH); + added_split_column = true; + } + } + column_widths = std::move(new_widths); + column_count = column_widths.size(); + + // update the values based on the columns that were pruned + for (auto &row : render_rows) { + if (row.row_type != RenderRowType::ROW_VALUES) { + continue; + } + vector values; + for (idx_t c = 0; c < column_map.size(); c++) { + auto column_idx = column_map[c]; + if (!column_idx.IsValid()) { + // insert the split column + values.emplace_back(config.DOTDOTDOT, ResultRenderType::LAYOUT, ValueRenderAlignment::MIDDLE); + values.back().render_width = 1; + } else { + values.push_back(std::move(row.values[column_idx.GetIndex()])); + } + } + row.values = std::move(values); + } + // check if we shortened any columns that would be rendered and if we can expand them + // we only expand columns in the ".mode rows", and only if we haven't hidden any columns + if (shortened_columns && config.render_mode == RenderMode::ROWS && render_rows.size() < config.max_rows && + !added_split_column) { + // if we have shortened any columns - try to expand them + // how many rows do we have left to expand before we hit the max row limit? + idx_t max_rows_per_row = MaxValue(1, config.max_rows <= 5 ? 
0 : (config.max_rows - 5) / row_count); + // for each row - figure out if we can "expand" the row + for (idx_t r = 0; r < render_rows.size(); r++) { + auto &row = render_rows[r]; + if (row.row_type != RenderRowType::ROW_VALUES) { + continue; + } + bool need_extra_row = r + 1 != render_rows.size() && r != 1; + idx_t min_rows = 2; + if (min_rows > max_rows_per_row) { + // no rows left to expand + continue; + } + // check if this row has truncated columns + vector<BoxRenderRow> extra_rows; + for (idx_t c = 0; c < row.values.size(); c++) { + if (CanPrettyPrint(row.values[c].type)) { + idx_t max_rows = max_rows_per_row; + if (need_extra_row) { + max_rows--; + } + PrettyPrintValue(row.values[c], max_rows, column_widths[c]); + if (CanHighlight(row.values[c].type)) { + HighlightValue(row.values[c]); + } + // FIXME: hacky + row.values[c].render_width = column_widths[c] + 1; + } + auto render_width = row.values[c].render_width.GetIndex(); + if (render_width <= column_widths[c]) { + // not shortened - skip + continue; + } + // this value was shortened! try to stretch it out + // first truncate what appears on the first row + idx_t current_row = 0; + idx_t current_pos = 0; + idx_t current_render_width = 0; + auto full_value = row.values[c].text; + auto annotations = row.values[c].annotations; + idx_t annotation_idx = 0; + ResultRenderType active_render_mode = ResultRenderType::VALUE; + row.values[c].annotations.clear(); + row.values[c].text = TruncateValue(full_value, column_widths[c], current_pos, current_render_width); + row.values[c].render_width = current_render_width; + row.values[c].decomposed = true; + // copy over annotations + for (; annotation_idx < annotations.size(); annotation_idx++) { + if (annotations[annotation_idx].start >= current_pos) { + break; + } + row.values[c].annotations.push_back(annotations[annotation_idx]); + } + while (current_pos < full_value.size()) { + if (current_row >= extra_rows.size()) { + if (extra_rows.size() >= max_rows_per_row + 1) { + // we need to add an extra row but there's no space anymore - break + break; + } + // add a new row with empty values + extra_rows.emplace_back(); + for (auto &current_val : row.values) { + extra_rows.back().values.emplace_back(string(), current_val.render_mode, + current_val.alignment, current_val.type); } - str = string(left_padding, ' ') + config.DOT + string(right_padding, ' '); + } + bool can_add_extra_row = + current_row + 1 < extra_rows.size() || extra_rows.size() < max_rows_per_row; + auto &extra_row = extra_rows[current_row++]; + idx_t start_pos = current_pos; + // stretch out the remainder on this row + current_render_width = 0; + if (can_add_extra_row) { + // if we can add an extra row after this row truncate it + extra_row.values[c].text = + TruncateValue(full_value, column_widths[c], current_pos, current_render_width); } else { - if (dot_length == 0) { - // everything is empty - alignment = ValueRenderAlignment::MIDDLE; + // if we cannot add an extra row after this just throw all remaining text on this row + extra_row.values[c].text = full_value.substr(current_pos); + current_render_width = Utf8Proc::RenderWidth(extra_row.values[c].text); + current_pos = full_value.size(); + } + extra_row.values[c].render_width = current_render_width; + extra_row.values[c].decomposed = true; + // copy over annotations + if (active_render_mode != ResultRenderType::VALUE) { + extra_row.values[c].annotations.emplace_back(active_render_mode, 0); + } + for (; annotation_idx < annotations.size(); annotation_idx++) { + if (annotations[annotation_idx].start >=
current_pos) { + break; } - str = config.DOT; + annotations[annotation_idx].start -= start_pos; + extra_row.values[c].annotations.push_back(annotations[annotation_idx]); + active_render_mode = annotations[annotation_idx].render_mode; } } - RenderValue(ss, str, widths[c], ResultRenderType::LAYOUT, alignment); } - ss << config.VERTICAL; - ss << '\n'; + if (extra_rows.empty()) { + continue; + } + // if we added extra rows we need to add a separator if this is not the last row + if (need_extra_row) { + extra_rows.emplace_back(RenderRowType::SEPARATOR); + } + // add the extra rows at the current position + render_rows.insert(render_rows.begin() + static_cast(r) + 1, extra_rows.begin(), extra_rows.end()); + r += extra_rows.size(); } - // note that the bottom rows are in reverse order - for (idx_t r = 0; r < bottom_rows; r++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - render_mode = ResultRenderType::LAYOUT; + } + // handle the row dividers + for (idx_t r = 0; r < render_rows.size(); r++) { + auto &row = render_rows[r]; + if (row.row_type != RenderRowType::DIVIDER) { + continue; + } + // generate three new rows + const idx_t divider_row_count = 3; + vector divider_rows; + for (idx_t d = 0; d < divider_row_count; d++) { + divider_rows.emplace_back(RenderRowType::ROW_VALUES); + } + + // find the prev/next rows + idx_t prev_row_idx, next_row_idx; + for (prev_row_idx = r; r > 0; r--) { + if (render_rows[prev_row_idx - 1].row_type == RenderRowType::ROW_VALUES) { + break; + } + } + for (next_row_idx = r + 1; r < render_rows.size(); r++) { + if (render_rows[next_row_idx].row_type == RenderRowType::ROW_VALUES) { + break; + } + } + if (prev_row_idx == 0 || next_row_idx >= render_rows.size()) { + throw InternalException("No prev/next row found"); + } + prev_row_idx--; + auto &prev_row = render_rows[prev_row_idx]; + auto &next_row = render_rows[next_row_idx]; + // now generate the dividers for each of the columns + + for (idx_t c = 0; c < column_count; c++) { + string str; + auto &prev_value = prev_row.values[c]; + auto &next_value = next_row.values[c]; + ValueRenderAlignment alignment = prev_value.alignment; + if (alignment == ValueRenderAlignment::MIDDLE) { + // for middle alignment we don't have to do anything - just push a dot + str = config.DOT; + } else { + // for left / right alignment we want to be in the middle of the prev / next value + auto top_length = MinValue(column_widths[c], Utf8Proc::RenderWidth(prev_value.text)); + auto bottom_length = MinValue(column_widths[c], Utf8Proc::RenderWidth(next_value.text)); + auto dot_length = MinValue(top_length, bottom_length); + if (top_length == 0) { + dot_length = bottom_length; + } else if (bottom_length == 0) { + dot_length = top_length; + } + if (dot_length > 1) { + auto padding = dot_length - 1; + idx_t left_padding, right_padding; + switch (alignment) { + case ValueRenderAlignment::LEFT: + left_padding = padding / 2; + right_padding = padding - left_padding; + break; + case ValueRenderAlignment::RIGHT: + right_padding = padding / 2; + left_padding = padding - right_padding; + break; + default: + throw InternalException("Unrecognized value renderer alignment"); + } + str = string(left_padding, ' ') + config.DOT + string(right_padding, ' '); } else { - str = GetRenderValue(ss, brows, column_idx, bottom_rows - r - 1, result_types[column_idx], - render_mode); + if (dot_length == 0) { + // everything is empty + 
alignment = ValueRenderAlignment::MIDDLE; + } + str = config.DOT; } - RenderValue(ss, str, widths[c], render_mode, alignments[c]); } - ss << config.VERTICAL; - ss << '\n'; + for (idx_t d = 0; d < divider_row_count; d++) { + divider_rows[d].values.emplace_back(str, ResultRenderType::LAYOUT, alignment); + } + } + // override the divider row with the row values + render_rows[r] = std::move(divider_rows[0]); + // insert the extra divider rows + for (idx_t d = 1; d < divider_row_count; d++) { + render_rows.insert(render_rows.begin() + static_cast(r), std::move(divider_rows[d])); + } + } +} + +void BoxRendererImplementation::RenderLayoutLine(const char *layout, const char *boundary, const char *left_corner, + const char *right_corner) { + // render the top line + ss << left_corner; + idx_t column_index = 0; + for (idx_t k = 0; k < total_render_length - 2; k++) { + if (column_index < column_boundary_positions.size() && k == column_boundary_positions[column_index]) { + ss << boundary; + column_index++; + } else { + ss << layout; } } + ss << right_corner; + ss << '\n'; } -void BoxRenderer::RenderRowCount(string &row_count_str, string &readable_rows_str, string &shown_str, - const string &column_count_str, const vector &boundaries, bool has_hidden_rows, - bool has_hidden_columns, idx_t total_length, idx_t row_count, idx_t column_count, - idx_t minimum_row_length, BaseResultRenderer &ss) { +void BoxRendererImplementation::RenderValues() { + auto column_count = column_widths.size(); + // render the top line + RenderLayoutLine(config.HORIZONTAL, config.TMIDDLE, config.LTCORNER, config.RTCORNER); + + for (idx_t r = 0; r < render_rows.size(); r++) { + auto &row = render_rows[r]; + if (row.row_type == RenderRowType::SEPARATOR) { + // render separator + RenderLayoutLine(config.HORIZONTAL, config.MIDDLE, config.LMIDDLE, config.RMIDDLE); + continue; + } + if (row.row_type == RenderRowType::DIVIDER) { + throw InternalException("Dividers should have been handled before"); + } + // render row values + for (idx_t column_idx = 0; column_idx < column_count; column_idx++) { + auto &render_value = row.values[column_idx]; + auto render_mode = render_value.render_mode; + auto alignment = render_value.alignment; + if (render_mode == ResultRenderType::NULL_VALUE || render_mode == ResultRenderType::VALUE) { + ss.SetValueType(render_value.type); + if (!render_value.decomposed && CanHighlight(render_value.type)) { + HighlightValue(render_value); + } + } + RenderValue(render_value.text, column_widths[column_idx], render_mode, render_value.annotations, alignment, + render_value.render_width); + } + ss << config.VERTICAL; + ss << '\n'; + } +} + +void BoxRendererImplementation::RenderFooter(idx_t row_count, idx_t column_count) { + auto &row_count_str = footer.row_count_str; + auto &column_count_str = footer.column_count_str; + auto &readable_rows_str = footer.readable_rows_str; + auto &shown_str = footer.shown_str; + auto &has_hidden_columns = footer.has_hidden_columns; + auto &has_hidden_rows = footer.has_hidden_rows; // check if we can merge the row_count_str, readable_rows_str and the shown_str auto minimum_length = row_count_str.size() + column_count_str.size() + 6; - bool render_rows_and_columns = total_length >= minimum_length && + bool render_rows_and_columns = total_render_length >= minimum_length && ((has_hidden_columns && row_count > 0) || (row_count >= 10 && column_count > 1)); - bool render_rows = total_length >= minimum_row_length && (row_count == 0 || row_count >= 10); + bool render_rows = total_render_length >= 
footer.render_length && (row_count == 0 || row_count >= 10); bool render_anything = true; if (!render_rows && !render_rows_and_columns) { render_anything = false; } // render the bottom of the result values, if there are any - if (row_count > 0) { - ss << (render_anything ? config.LMIDDLE : config.LDCORNER); - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < boundaries.size() && k == boundaries[column_index]) { - ss << config.DMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << (render_anything ? config.RMIDDLE : config.RDCORNER); - ss << '\n'; - } + RenderLayoutLine(config.HORIZONTAL, config.DMIDDLE, render_anything ? config.LMIDDLE : config.LDCORNER, + render_anything ? config.RMIDDLE : config.RDCORNER); if (!render_anything) { return; } - idx_t padding = total_length - row_count_str.size() - 4; + idx_t padding = total_render_length - row_count_str.size() - 4; if (render_rows_and_columns) { padding -= column_count_str.size(); } @@ -980,12 +2049,12 @@ void BoxRenderer::RenderRowCount(string &row_count_str, string &readable_rows_st // we still need to render the readable rows/shown strings // check if we can merge the two onto one row idx_t combined_shown_length = readable_rows_str.size() + shown_str.size() + 4; - if (!readable_rows_str.empty() && !shown_str.empty() && combined_shown_length <= total_length) { + if (!readable_rows_str.empty() && !shown_str.empty() && combined_shown_length <= total_render_length) { // we can! merge them ss << config.VERTICAL; ss << " "; ss.Render(ResultRenderType::NULL_VALUE, readable_rows_str); - ss << string(total_length - combined_shown_length, ' '); + ss << string(total_render_length - combined_shown_length, ' '); ss.Render(ResultRenderType::NULL_VALUE, shown_str); ss << " "; ss << config.VERTICAL; @@ -995,148 +2064,44 @@ void BoxRenderer::RenderRowCount(string &row_count_str, string &readable_rows_st } ValueRenderAlignment alignment = render_rows_and_columns ? 
ValueRenderAlignment::LEFT : ValueRenderAlignment::MIDDLE; + vector annotations; if (!readable_rows_str.empty()) { - RenderValue(ss, "(" + readable_rows_str + ")", total_length - 4, ResultRenderType::NULL_VALUE, alignment); + RenderValue("(" + readable_rows_str + ")", total_render_length - 4, ResultRenderType::NULL_VALUE, + annotations, alignment); ss << config.VERTICAL; ss << '\n'; } if (!shown_str.empty()) { - RenderValue(ss, "(" + shown_str + ")", total_length - 4, ResultRenderType::NULL_VALUE, alignment); + RenderValue("(" + shown_str + ")", total_render_length - 4, ResultRenderType::NULL_VALUE, annotations, + alignment); ss << config.VERTICAL; ss << '\n'; } } // render the bottom line - ss << config.LDCORNER; - for (idx_t k = 0; k < total_length - 2; k++) { - ss << config.HORIZONTAL; - } - ss << config.RDCORNER; - ss << '\n'; + RenderLayoutLine(config.HORIZONTAL, config.HORIZONTAL, config.LDCORNER, config.RDCORNER); } -void BoxRenderer::Render(ClientContext &context, const vector &names, const ColumnDataCollection &result, - BaseResultRenderer &ss) { - if (result.ColumnCount() != names.size()) { - throw InternalException("Error in BoxRenderer::Render - unaligned columns and names"); - } - auto max_width = config.max_width; - if (max_width == 0) { - if (Printer::IsTerminal(OutputStream::STREAM_STDOUT)) { - max_width = Printer::TerminalWidth(); - } else { - max_width = 120; - } - } - // we do not support max widths under 80 - max_width = MaxValue(80, max_width); - - // figure out how many/which rows to render - idx_t row_count = result.Count(); - idx_t rows_to_render = MinValue(row_count, config.max_rows); - if (row_count <= config.max_rows + 3) { - // hiding rows adds 3 extra rows - // so hiding rows makes no sense if we are only slightly over the limit - // if we are 1 row over the limit hiding rows will actually increase the number of lines we display! - // in this case render all the rows - rows_to_render = row_count; - } - idx_t top_rows; - idx_t bottom_rows; - if (rows_to_render == row_count) { - top_rows = row_count; - bottom_rows = 0; - } else { - top_rows = rows_to_render / 2 + (rows_to_render % 2 != 0 ? 1 : 0); - bottom_rows = rows_to_render - top_rows; - } - auto row_count_str = FormatNumber(to_string(row_count)) + " rows"; - bool has_limited_rows = config.limit > 0 && row_count == config.limit; - if (has_limited_rows) { - row_count_str = "? 
rows"; - } - string readable_rows_str; - if (config.large_number_rendering == LargeNumberRendering::FOOTER && !has_limited_rows) { - readable_rows_str = TryFormatLargeNumber(to_string(row_count)); - if (!readable_rows_str.empty()) { - readable_rows_str += " rows"; - } - } - string shown_str; - bool has_hidden_rows = top_rows < row_count; - if (has_hidden_rows) { - if (has_limited_rows) { - shown_str += ">" + FormatNumber(to_string(config.limit - 1)) + " rows, "; - } - shown_str += FormatNumber(to_string(top_rows + bottom_rows)) + " shown"; - } - auto minimum_row_length = - MaxValue(MaxValue(row_count_str.size(), shown_str.size() + 2), readable_rows_str.size() + 2) + 4; - - // fetch the top and bottom render collections from the result - auto collections = FetchRenderCollections(context, result, top_rows, bottom_rows); - auto column_names = names; - auto result_types = result.Types(); - if (config.render_mode == RenderMode::COLUMNS && rows_to_render > 0) { - collections = PivotCollections(context, std::move(collections), column_names, result_types, row_count); - } - - // for each column, figure out the width - // start off by figuring out the name of the header by looking at the column name and column type - idx_t min_width = has_hidden_rows || row_count == 0 ? minimum_row_length : 0; - vector column_map; - idx_t total_length; - auto widths = - ComputeRenderWidths(column_names, result_types, collections, min_width, max_width, column_map, total_length); - - // render boundaries for the individual columns - vector boundaries; - for (idx_t c = 0; c < widths.size(); c++) { - idx_t render_boundary; - if (c == 0) { - render_boundary = widths[c] + 2; - } else { - render_boundary = boundaries[c - 1] + widths[c] + 3; - } - boundaries.push_back(render_boundary); - } - - // now begin rendering - // first render the header - RenderHeader(column_names, result_types, column_map, widths, boundaries, total_length, row_count > 0, ss); +//===--------------------------------------------------------------------===// +// Box Renderer +//===--------------------------------------------------------------------===// +BoxRenderer::BoxRenderer(BoxRendererConfig config_p) : config(std::move(config_p)) { +} - // render the values, if there are any - RenderValues(collections, column_map, widths, result_types, ss); +string BoxRenderer::ToString(ClientContext &context, const vector &names, const ColumnDataCollection &result) { + StringResultRenderer ss; + Render(context, names, result, ss); + return ss.str(); +} - // render the row count and column count - auto column_count_str = to_string(result.ColumnCount()) + " column"; - if (result.ColumnCount() > 1) { - column_count_str += "s"; - } - bool has_hidden_columns = false; - for (auto entry : column_map) { - if (entry == SPLIT_COLUMN) { - has_hidden_columns = true; - break; - } - } - idx_t column_count = column_map.size(); - if (config.render_mode == RenderMode::COLUMNS) { - if (has_hidden_columns) { - has_hidden_rows = true; - shown_str = to_string(column_count - 3) + " shown"; - } else { - shown_str = string(); - } - } else { - if (has_hidden_columns) { - column_count--; - column_count_str += " (" + to_string(column_count) + " shown)"; - } - } +void BoxRenderer::Print(ClientContext &context, const vector &names, const ColumnDataCollection &result) { + Printer::Print(ToString(context, names, result)); +} - RenderRowCount(row_count_str, readable_rows_str, shown_str, column_count_str, boundaries, has_hidden_rows, - has_hidden_columns, total_length, row_count, column_count, 
minimum_row_length, ss); +void BoxRenderer::Render(ClientContext &context, const vector &names, const ColumnDataCollection &result, + BaseResultRenderer &ss) { + BoxRendererImplementation implementation(config, context, names, result, ss); + implementation.Render(); } } // namespace duckdb diff --git a/src/duckdb/src/common/compressed_file_system.cpp b/src/duckdb/src/common/compressed_file_system.cpp index 593786204..c8b51d3d4 100644 --- a/src/duckdb/src/common/compressed_file_system.cpp +++ b/src/duckdb/src/common/compressed_file_system.cpp @@ -15,7 +15,19 @@ CompressedFile::~CompressedFile() { try { // stream_wrapper->Close() might throw CompressedFile::Close(); - } catch (...) { // NOLINT - cannot throw in exception + } catch (std::exception &ex) { + if (child_handle) { + // FIXME: Make any log context available here. + ErrorData data(ex); + try { + const auto logger = child_handle->logger; + if (logger) { + DUCKDB_LOG_ERROR(logger, "CompressedFile::~CompressedFile()\t\t" + data.Message()); + } + } catch (...) { // NOLINT + } + } + } catch (...) { // NOLINT } } diff --git a/src/duckdb/src/common/csv_writer.cpp b/src/duckdb/src/common/csv_writer.cpp index bb9ff81d2..8f8992347 100644 --- a/src/duckdb/src/common/csv_writer.cpp +++ b/src/duckdb/src/common/csv_writer.cpp @@ -16,7 +16,7 @@ CSVWriterState::CSVWriterState() } CSVWriterState::CSVWriterState(ClientContext &context, idx_t flush_size_p) - : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context))) { + : flush_size(flush_size_p), stream(make_uniq(Allocator::Get(context), flush_size)) { } CSVWriterState::CSVWriterState(DatabaseInstance &db, idx_t flush_size_p) @@ -71,7 +71,6 @@ CSVWriter::CSVWriter(CSVReaderOptions &options_p, FileSystem &fs, const string & FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW | FileLockType::WRITE_LOCK | compression)), write_stream(*file_writer), should_initialize(true), shared(shared) { - if (!shared) { global_write_state = make_uniq(); } @@ -198,18 +197,6 @@ void CSVWriter::ResetInternal(optional_ptr local_state) { bytes_written = 0; } -unique_ptr CSVWriter::InitializeLocalWriteState(ClientContext &context, idx_t flush_size) { - auto res = make_uniq(context, flush_size); - res->stream = make_uniq(); - return res; -} - -unique_ptr CSVWriter::InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size) { - auto res = make_uniq(db, flush_size); - res->stream = make_uniq(); - return res; -} - idx_t CSVWriter::BytesWritten() { if (shared) { lock_guard flock(lock); diff --git a/src/duckdb/src/common/encryption_functions.cpp b/src/duckdb/src/common/encryption_functions.cpp index 1ecf1abeb..262ff3c14 100644 --- a/src/duckdb/src/common/encryption_functions.cpp +++ b/src/duckdb/src/common/encryption_functions.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/encryption_key_manager.hpp" #include "duckdb/common/encryption_functions.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/main/attached_database.hpp" #include "mbedtls_wrapper.hpp" #include "duckdb/storage/storage_manager.hpp" @@ -30,6 +31,22 @@ idx_t EncryptionNonce::size() const { return MainHeader::AES_NONCE_LEN; } +constexpr uint32_t AdditionalAuthenticatedData::INITIAL_AAD_CAPACITY; + +AdditionalAuthenticatedData::~AdditionalAuthenticatedData() = default; + +data_ptr_t AdditionalAuthenticatedData::data() const { + return additional_authenticated_data->GetData(); +} + +idx_t AdditionalAuthenticatedData::size() const { + return 
additional_authenticated_data->GetPosition(); +} + +void AdditionalAuthenticatedData::WriteStringData(const std::string &val) const { + additional_authenticated_data->WriteData(reinterpret_cast(val.data()), val.size()); +} + EncryptionEngine::EncryptionEngine() { } diff --git a/src/duckdb/src/common/encryption_key_manager.cpp b/src/duckdb/src/common/encryption_key_manager.cpp index 482c4a006..b0044a5f0 100644 --- a/src/duckdb/src/common/encryption_key_manager.cpp +++ b/src/duckdb/src/common/encryption_key_manager.cpp @@ -19,7 +19,8 @@ EncryptionKey::EncryptionKey(data_ptr_t encryption_key_p) { D_ASSERT(memcmp(key, encryption_key_p, MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH) == 0); // zero out the encryption key in memory - memset(encryption_key_p, 0, MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH); + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(encryption_key_p, + MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH); LockEncryptionKey(key); } @@ -31,15 +32,19 @@ EncryptionKey::~EncryptionKey() { void EncryptionKey::LockEncryptionKey(data_ptr_t key, idx_t key_len) { #if defined(_WIN32) VirtualLock(key, key_len); +#elif defined(__MVS__) + __mlockall(_BPX_NONSWAP); #else mlock(key, key_len); #endif } void EncryptionKey::UnlockEncryptionKey(data_ptr_t key, idx_t key_len) { - memset(key, 0, key_len); + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(key, key_len); #if defined(_WIN32) VirtualUnlock(key, key_len); +#elif defined(__MVS__) + __mlockall(_BPX_SWAP); #else munlock(key, key_len); #endif @@ -64,27 +69,32 @@ EncryptionKeyManager &EncryptionKeyManager::Get(DatabaseInstance &db) { string EncryptionKeyManager::GenerateRandomKeyID() { uint8_t key_id[KEY_ID_BYTES]; - duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(key_id, KEY_ID_BYTES); + RandomEngine engine; + engine.RandomData(key_id, KEY_ID_BYTES); string key_id_str(reinterpret_cast(key_id), KEY_ID_BYTES); return key_id_str; } void EncryptionKeyManager::AddKey(const string &key_name, data_ptr_t key) { + lock_guard guard(lock); derived_keys.emplace(key_name, EncryptionKey(key)); // Zero-out the encryption key - std::memset(key, 0, DERIVED_KEY_LENGTH); + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(key, DERIVED_KEY_LENGTH); } bool EncryptionKeyManager::HasKey(const string &key_name) const { + lock_guard guard(lock); return derived_keys.find(key_name) != derived_keys.end(); } const_data_ptr_t EncryptionKeyManager::GetKey(const string &key_name) const { D_ASSERT(HasKey(key_name)); + lock_guard guard(lock); return derived_keys.at(key_name).GetPtr(); } void EncryptionKeyManager::DeleteKey(const string &key_name) { + lock_guard guard(lock); derived_keys.erase(key_name); } @@ -107,7 +117,7 @@ string EncryptionKeyManager::Base64Decode(const string &key) { auto output = duckdb::unique_ptr(new unsigned char[result_size]); Blob::FromBase64(key, output.get(), result_size); string decoded_key(reinterpret_cast(output.get()), result_size); - memset(output.get(), 0, result_size); + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(output.get(), result_size); return decoded_key; } @@ -124,10 +134,9 @@ void EncryptionKeyManager::DeriveKey(string &user_key, data_ptr_t salt, data_ptr KeyDerivationFunctionSHA256(reinterpret_cast(decoded_key.data()), decoded_key.size(), salt, derived_key); - - // wipe the original and decoded key - std::fill(user_key.begin(), user_key.end(), 0); - std::fill(decoded_key.begin(), decoded_key.end(), 0); + 
duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&user_key[0]), user_key.size()); + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&decoded_key[0]), + decoded_key.size()); user_key.clear(); decoded_key.clear(); } diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 324ba7004..580a2e7d7 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -82,27 +82,30 @@ #include "duckdb/common/multi_file/multi_file_options.hpp" #include "duckdb/common/operator/decimal_cast_operators.hpp" #include "duckdb/common/printer.hpp" -#include "duckdb/common/sort/partition_state.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/column/column_data_scan_states.hpp" #include "duckdb/common/types/column/partitioned_column_data.hpp" #include "duckdb/common/types/conflict_manager.hpp" #include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types/hyperloglog.hpp" #include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/common/types/row/partitioned_tuple_data.hpp" #include "duckdb/common/types/row/tuple_data_states.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/variant.hpp" +#include "duckdb/common/types/variant_value.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/types/vector_buffer.hpp" #include "duckdb/execution/index/art/art.hpp" #include "duckdb/execution/index/art/art_scanner.hpp" #include "duckdb/execution/index/art/node.hpp" #include "duckdb/execution/index/bound_index.hpp" +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/execution/operator/csv_scanner/csv_option.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" #include "duckdb/execution/reservoir_sample.hpp" #include "duckdb/function/aggregate_state.hpp" #include "duckdb/function/compression_function.hpp" @@ -121,15 +124,18 @@ #include "duckdb/logging/log_storage.hpp" #include "duckdb/logging/logging.hpp" #include "duckdb/main/appender.hpp" +#include "duckdb/main/attached_database.hpp" #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/error_manager.hpp" #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_helper.hpp" #include "duckdb/main/extension_install_info.hpp" +#include "duckdb/main/query_parameters.hpp" #include "duckdb/main/query_profiler.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/main/setting_info.hpp" +#include "duckdb/parallel/async_result.hpp" #include "duckdb/parallel/interrupt.hpp" #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/task.hpp" @@ -162,9 +168,11 @@ #include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/caching_file_system_wrapper.hpp" #include "duckdb/storage/compression/bitpacking.hpp" #include "duckdb/storage/magic_bytes.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/table_index_list.hpp" @@ -631,6 +639,45 @@ ArrowVariableSizeType EnumUtil::FromString(const char *va return 
static_cast<ArrowVariableSizeType>(StringUtil::StringToEnum(GetArrowVariableSizeTypeValues(), 4, "ArrowVariableSizeType", value)); } +const StringUtil::EnumStringLiteral *GetAsyncResultTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast<uint32_t>(AsyncResultType::INVALID), "INVALID" }, + { static_cast<uint32_t>(AsyncResultType::IMPLICIT), "IMPLICIT" }, + { static_cast<uint32_t>(AsyncResultType::HAVE_MORE_OUTPUT), "HAVE_MORE_OUTPUT" }, + { static_cast<uint32_t>(AsyncResultType::FINISHED), "FINISHED" }, + { static_cast<uint32_t>(AsyncResultType::BLOCKED), "BLOCKED" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars<AsyncResultType>(AsyncResultType value) { + return StringUtil::EnumToString(GetAsyncResultTypeValues(), 5, "AsyncResultType", static_cast<uint32_t>(value)); +} + +template<> +AsyncResultType EnumUtil::FromString<AsyncResultType>(const char *value) { + return static_cast<AsyncResultType>(StringUtil::StringToEnum(GetAsyncResultTypeValues(), 5, "AsyncResultType", value)); +} + +const StringUtil::EnumStringLiteral *GetAsyncResultsExecutionModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast<uint32_t>(AsyncResultsExecutionMode::SYNCHRONOUS), "SYNCHRONOUS" }, + { static_cast<uint32_t>(AsyncResultsExecutionMode::TASK_EXECUTOR), "TASK_EXECUTOR" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars<AsyncResultsExecutionMode>(AsyncResultsExecutionMode value) { + return StringUtil::EnumToString(GetAsyncResultsExecutionModeValues(), 2, "AsyncResultsExecutionMode", static_cast<uint32_t>(value)); +} + +template<> +AsyncResultsExecutionMode EnumUtil::FromString<AsyncResultsExecutionMode>(const char *value) { + return static_cast<AsyncResultsExecutionMode>(StringUtil::StringToEnum(GetAsyncResultsExecutionModeValues(), 2, "AsyncResultsExecutionMode", value)); +} + const StringUtil::EnumStringLiteral *GetBinderTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast<uint32_t>(BinderType::REGULAR_BINDER), "REGULAR_BINDER" }, @@ -727,6 +774,24 @@ BlockState EnumUtil::FromString<BlockState>(const char *value) { return static_cast<BlockState>(StringUtil::StringToEnum(GetBlockStateValues(), 2, "BlockState", value)); } +const StringUtil::EnumStringLiteral *GetBufferedIndexReplayValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast<uint32_t>(BufferedIndexReplay::INSERT_ENTRY), "INSERT_ENTRY" }, + { static_cast<uint32_t>(BufferedIndexReplay::DEL_ENTRY), "DEL_ENTRY" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars<BufferedIndexReplay>(BufferedIndexReplay value) { + return StringUtil::EnumToString(GetBufferedIndexReplayValues(), 2, "BufferedIndexReplay", static_cast<uint32_t>(value)); +} + +template<> +BufferedIndexReplay EnumUtil::FromString<BufferedIndexReplay>(const char *value) { + return static_cast<BufferedIndexReplay>(StringUtil::StringToEnum(GetBufferedIndexReplayValues(), 2, "BufferedIndexReplay", value)); +} + const StringUtil::EnumStringLiteral *GetCAPIResultSetTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast<uint32_t>(CAPIResultSetType::CAPI_RESULT_TYPE_NONE), "CAPI_RESULT_TYPE_NONE" }, @@ -801,6 +866,24 @@ CTEMaterialize EnumUtil::FromString<CTEMaterialize>(const char *value) { return static_cast<CTEMaterialize>(StringUtil::StringToEnum(GetCTEMaterializeValues(), 3, "CTEMaterialize", value)); } +const StringUtil::EnumStringLiteral *GetCachingModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast<uint32_t>(CachingMode::ALWAYS_CACHE), "ALWAYS_CACHE" }, + { static_cast<uint32_t>(CachingMode::CACHE_REMOTE_ONLY), "CACHE_REMOTE_ONLY" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars<CachingMode>(CachingMode value) { + return StringUtil::EnumToString(GetCachingModeValues(), 2, "CachingMode", static_cast<uint32_t>(value)); +} + +template<> +CachingMode
EnumUtil::FromString<CachingMode>(const char *value) { + return static_cast<CachingMode>(StringUtil::StringToEnum(GetCachingModeValues(), 2, "CachingMode", value)); +} + const StringUtil::EnumStringLiteral *GetCatalogLookupBehaviorValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast<uint32_t>(CatalogLookupBehavior::STANDARD), "STANDARD" }, @@ -864,19 +947,22 @@ const StringUtil::EnumStringLiteral *GetCheckpointAbortValues() { { static_cast<uint32_t>(CheckpointAbort::NO_ABORT), "NONE" }, { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE), "BEFORE_TRUNCATE" }, { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER), "BEFORE_HEADER" }, - { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE), "AFTER_FREE_LIST_WRITE" } + { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE), "AFTER_FREE_LIST_WRITE" }, + { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_BEFORE_WAL_FINISH), "BEFORE_WAL_FINISH" }, + { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_BEFORE_MOVING_RECOVERY), "BEFORE_MOVING_RECOVERY" }, + { static_cast<uint32_t>(CheckpointAbort::DEBUG_ABORT_BEFORE_DELETING_CHECKPOINT_WAL), "BEFORE_DELETING_CHECKPOINT_WAL" } }; return values; } template<> const char* EnumUtil::ToChars<CheckpointAbort>(CheckpointAbort value) { - return StringUtil::EnumToString(GetCheckpointAbortValues(), 4, "CheckpointAbort", static_cast<uint32_t>(value)); + return StringUtil::EnumToString(GetCheckpointAbortValues(), 7, "CheckpointAbort", static_cast<uint32_t>(value)); } template<> CheckpointAbort EnumUtil::FromString<CheckpointAbort>(const char *value) { - return static_cast<CheckpointAbort>(StringUtil::StringToEnum(GetCheckpointAbortValues(), 4, "CheckpointAbort", value)); + return static_cast<CheckpointAbort>(StringUtil::StringToEnum(GetCheckpointAbortValues(), 7, "CheckpointAbort", value)); } const StringUtil::EnumStringLiteral *GetChunkInfoTypeValues() { @@ -1464,19 +1550,20 @@ const StringUtil::EnumStringLiteral *GetExplainFormatValues() { { static_cast<uint32_t>(ExplainFormat::JSON), "JSON" }, { static_cast<uint32_t>(ExplainFormat::HTML), "HTML" }, { static_cast<uint32_t>(ExplainFormat::GRAPHVIZ), "GRAPHVIZ" }, - { static_cast<uint32_t>(ExplainFormat::YAML), "YAML" } + { static_cast<uint32_t>(ExplainFormat::YAML), "YAML" }, + { static_cast<uint32_t>(ExplainFormat::MERMAID), "MERMAID" } }; return values; } template<> const char* EnumUtil::ToChars<ExplainFormat>(ExplainFormat value) { - return StringUtil::EnumToString(GetExplainFormatValues(), 6, "ExplainFormat", static_cast<uint32_t>(value)); + return StringUtil::EnumToString(GetExplainFormatValues(), 7, "ExplainFormat", static_cast<uint32_t>(value)); } template<> ExplainFormat EnumUtil::FromString<ExplainFormat>(const char *value) { - return static_cast<ExplainFormat>(StringUtil::StringToEnum(GetExplainFormatValues(), 6, "ExplainFormat", value)); + return static_cast<ExplainFormat>(StringUtil::StringToEnum(GetExplainFormatValues(), 7, "ExplainFormat", value)); } const StringUtil::EnumStringLiteral *GetExplainOutputTypeValues() { @@ -1795,19 +1882,20 @@ const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() { { static_cast<uint32_t>(ExtraTypeInfoType::ARRAY_TYPE_INFO), "ARRAY_TYPE_INFO" }, { static_cast<uint32_t>(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" }, { static_cast<uint32_t>(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }, - { static_cast<uint32_t>(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" } + { static_cast<uint32_t>(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" }, + { static_cast<uint32_t>(ExtraTypeInfoType::GEO_TYPE_INFO), "GEO_TYPE_INFO" } }; return values; } template<> const char* EnumUtil::ToChars<ExtraTypeInfoType>(ExtraTypeInfoType value) { - return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", static_cast<uint32_t>(value)); + return
StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", static_cast(value)); } template<> ExtraTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", value)); } const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() { @@ -2059,6 +2147,30 @@ GateStatus EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetGateStatusValues(), 2, "GateStatus", value)); } +const StringUtil::EnumStringLiteral *GetGeometryTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(GeometryType::INVALID), "INVALID" }, + { static_cast(GeometryType::POINT), "POINT" }, + { static_cast(GeometryType::LINESTRING), "LINESTRING" }, + { static_cast(GeometryType::POLYGON), "POLYGON" }, + { static_cast(GeometryType::MULTIPOINT), "MULTIPOINT" }, + { static_cast(GeometryType::MULTILINESTRING), "MULTILINESTRING" }, + { static_cast(GeometryType::MULTIPOLYGON), "MULTIPOLYGON" }, + { static_cast(GeometryType::GEOMETRYCOLLECTION), "GEOMETRYCOLLECTION" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(GeometryType value) { + return StringUtil::EnumToString(GetGeometryTypeValues(), 8, "GeometryType", static_cast(value)); +} + +template<> +GeometryType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetGeometryTypeValues(), 8, "GeometryType", value)); +} + const StringUtil::EnumStringLiteral *GetHLLStorageTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(HLLStorageType::HLL_V1), "HLL_V1" }, @@ -2424,7 +2536,7 @@ const StringUtil::EnumStringLiteral *GetLogLevelValues() { { static_cast(LogLevel::LOG_TRACE), "TRACE" }, { static_cast(LogLevel::LOG_DEBUG), "DEBUG" }, { static_cast(LogLevel::LOG_INFO), "INFO" }, - { static_cast(LogLevel::LOG_WARN), "WARN" }, + { static_cast(LogLevel::LOG_WARNING), "WARNING" }, { static_cast(LogLevel::LOG_ERROR), "ERROR" }, { static_cast(LogLevel::LOG_FATAL), "FATAL" } }; @@ -2599,6 +2711,7 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::POINTER), "POINTER" }, { static_cast(LogicalTypeId::VALIDITY), "VALIDITY" }, { static_cast(LogicalTypeId::UUID), "UUID" }, + { static_cast(LogicalTypeId::GEOMETRY), "GEOMETRY" }, { static_cast(LogicalTypeId::STRUCT), "STRUCT" }, { static_cast(LogicalTypeId::LIST), "LIST" }, { static_cast(LogicalTypeId::MAP), "MAP" }, @@ -2615,12 +2728,12 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { template<> const char* EnumUtil::ToChars(LogicalTypeId value) { - return StringUtil::EnumToString(GetLogicalTypeIdValues(), 50, "LogicalTypeId", static_cast(value)); + return StringUtil::EnumToString(GetLogicalTypeIdValues(), 51, "LogicalTypeId", static_cast(value)); } template<> LogicalTypeId EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 50, "LogicalTypeId", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 51, "LogicalTypeId", value)); } const StringUtil::EnumStringLiteral *GetLookupResultTypeValues() { @@ -2697,19 +2810,20 @@ const StringUtil::EnumStringLiteral *GetMemoryTagValues() { { static_cast(MemoryTag::ALLOCATOR), "ALLOCATOR" }, { static_cast(MemoryTag::EXTENSION), "EXTENSION" }, { static_cast(MemoryTag::TRANSACTION), 
"TRANSACTION" }, - { static_cast(MemoryTag::EXTERNAL_FILE_CACHE), "EXTERNAL_FILE_CACHE" } + { static_cast(MemoryTag::EXTERNAL_FILE_CACHE), "EXTERNAL_FILE_CACHE" }, + { static_cast(MemoryTag::WINDOW), "WINDOW" } }; return values; } template<> const char* EnumUtil::ToChars(MemoryTag value) { - return StringUtil::EnumToString(GetMemoryTagValues(), 14, "MemoryTag", static_cast(value)); + return StringUtil::EnumToString(GetMemoryTagValues(), 15, "MemoryTag", static_cast(value)); } template<> MemoryTag EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMemoryTagValues(), 14, "MemoryTag", value)); + return static_cast(StringUtil::StringToEnum(GetMemoryTagValues(), 15, "MemoryTag", value)); } const StringUtil::EnumStringLiteral *GetMergeActionConditionValues() { @@ -2770,74 +2884,111 @@ MetaPipelineType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetMetaPipelineTypeValues(), 2, "MetaPipelineType", value)); } -const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, - { static_cast(MetricsType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, - { static_cast(MetricsType::CPU_TIME), "CPU_TIME" }, - { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, - { static_cast(MetricsType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, - { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, - { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, - { static_cast(MetricsType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, - { static_cast(MetricsType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, - { static_cast(MetricsType::OPERATOR_TIMING), "OPERATOR_TIMING" }, - { static_cast(MetricsType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, - { static_cast(MetricsType::LATENCY), "LATENCY" }, - { static_cast(MetricsType::ROWS_RETURNED), "ROWS_RETURNED" }, - { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, - { static_cast(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY), "SYSTEM_PEAK_BUFFER_MEMORY" }, - { static_cast(MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE), "SYSTEM_PEAK_TEMP_DIR_SIZE" }, - { static_cast(MetricsType::TOTAL_BYTES_READ), "TOTAL_BYTES_READ" }, - { static_cast(MetricsType::TOTAL_BYTES_WRITTEN), "TOTAL_BYTES_WRITTEN" }, - { static_cast(MetricsType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, - { static_cast(MetricsType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, - { static_cast(MetricsType::PLANNER), "PLANNER" }, - { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, - { static_cast(MetricsType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, - { static_cast(MetricsType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, - { static_cast(MetricsType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP), "OPTIMIZER_EMPTY_RESULT_PULLUP" }, - { static_cast(MetricsType::OPTIMIZER_CTE_FILTER_PUSHER), "OPTIMIZER_CTE_FILTER_PUSHER" }, - { static_cast(MetricsType::OPTIMIZER_REGEX_RANGE), 
"OPTIMIZER_REGEX_RANGE" }, - { static_cast(MetricsType::OPTIMIZER_IN_CLAUSE), "OPTIMIZER_IN_CLAUSE" }, - { static_cast(MetricsType::OPTIMIZER_JOIN_ORDER), "OPTIMIZER_JOIN_ORDER" }, - { static_cast(MetricsType::OPTIMIZER_DELIMINATOR), "OPTIMIZER_DELIMINATOR" }, - { static_cast(MetricsType::OPTIMIZER_UNNEST_REWRITER), "OPTIMIZER_UNNEST_REWRITER" }, - { static_cast(MetricsType::OPTIMIZER_UNUSED_COLUMNS), "OPTIMIZER_UNUSED_COLUMNS" }, - { static_cast(MetricsType::OPTIMIZER_STATISTICS_PROPAGATION), "OPTIMIZER_STATISTICS_PROPAGATION" }, - { static_cast(MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS), "OPTIMIZER_COMMON_SUBEXPRESSIONS" }, - { static_cast(MetricsType::OPTIMIZER_COMMON_AGGREGATE), "OPTIMIZER_COMMON_AGGREGATE" }, - { static_cast(MetricsType::OPTIMIZER_COLUMN_LIFETIME), "OPTIMIZER_COLUMN_LIFETIME" }, - { static_cast(MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, - { static_cast(MetricsType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, - { static_cast(MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, - { static_cast(MetricsType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, - { static_cast(MetricsType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, - { static_cast(MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN), "OPTIMIZER_SAMPLING_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN), "OPTIMIZER_JOIN_FILTER_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_EXTENSION), "OPTIMIZER_EXTENSION" }, - { static_cast(MetricsType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, - { static_cast(MetricsType::OPTIMIZER_SUM_REWRITER), "OPTIMIZER_SUM_REWRITER" }, - { static_cast(MetricsType::OPTIMIZER_LATE_MATERIALIZATION), "OPTIMIZER_LATE_MATERIALIZATION" }, - { static_cast(MetricsType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 54, "MetricsType", static_cast(value)); -} - -template<> -MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 54, "MetricsType", value)); +const StringUtil::EnumStringLiteral *GetMetricGroupValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(MetricGroup::ALL), "ALL" }, + { static_cast(MetricGroup::CORE), "CORE" }, + { static_cast(MetricGroup::DEFAULT), "DEFAULT" }, + { static_cast(MetricGroup::EXECUTION), "EXECUTION" }, + { static_cast(MetricGroup::FILE), "FILE" }, + { static_cast(MetricGroup::OPERATOR), "OPERATOR" }, + { static_cast(MetricGroup::OPTIMIZER), "OPTIMIZER" }, + { static_cast(MetricGroup::PHASE_TIMING), "PHASE_TIMING" }, + { static_cast(MetricGroup::INVALID), "INVALID" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(MetricGroup value) { + return StringUtil::EnumToString(GetMetricGroupValues(), 9, "MetricGroup", static_cast(value)); +} + +template<> +MetricGroup EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetMetricGroupValues(), 9, "MetricGroup", value)); +} + +const StringUtil::EnumStringLiteral *GetMetricTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(MetricType::CPU_TIME), "CPU_TIME" }, + { static_cast(MetricType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, + { 
static_cast(MetricType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, + { static_cast(MetricType::EXTRA_INFO), "EXTRA_INFO" }, + { static_cast(MetricType::LATENCY), "LATENCY" }, + { static_cast(MetricType::QUERY_NAME), "QUERY_NAME" }, + { static_cast(MetricType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, + { static_cast(MetricType::ROWS_RETURNED), "ROWS_RETURNED" }, + { static_cast(MetricType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, + { static_cast(MetricType::SYSTEM_PEAK_BUFFER_MEMORY), "SYSTEM_PEAK_BUFFER_MEMORY" }, + { static_cast(MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE), "SYSTEM_PEAK_TEMP_DIR_SIZE" }, + { static_cast(MetricType::TOTAL_MEMORY_ALLOCATED), "TOTAL_MEMORY_ALLOCATED" }, + { static_cast(MetricType::ATTACH_LOAD_STORAGE_LATENCY), "ATTACH_LOAD_STORAGE_LATENCY" }, + { static_cast(MetricType::ATTACH_REPLAY_WAL_LATENCY), "ATTACH_REPLAY_WAL_LATENCY" }, + { static_cast(MetricType::CHECKPOINT_LATENCY), "CHECKPOINT_LATENCY" }, + { static_cast(MetricType::COMMIT_LOCAL_STORAGE_LATENCY), "COMMIT_LOCAL_STORAGE_LATENCY" }, + { static_cast(MetricType::TOTAL_BYTES_READ), "TOTAL_BYTES_READ" }, + { static_cast(MetricType::TOTAL_BYTES_WRITTEN), "TOTAL_BYTES_WRITTEN" }, + { static_cast(MetricType::WAITING_TO_ATTACH_LATENCY), "WAITING_TO_ATTACH_LATENCY" }, + { static_cast(MetricType::WAL_REPLAY_ENTRY_COUNT), "WAL_REPLAY_ENTRY_COUNT" }, + { static_cast(MetricType::WRITE_TO_WAL_LATENCY), "WRITE_TO_WAL_LATENCY" }, + { static_cast(MetricType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, + { static_cast(MetricType::OPERATOR_NAME), "OPERATOR_NAME" }, + { static_cast(MetricType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, + { static_cast(MetricType::OPERATOR_TIMING), "OPERATOR_TIMING" }, + { static_cast(MetricType::OPERATOR_TYPE), "OPERATOR_TYPE" }, + { static_cast(MetricType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, + { static_cast(MetricType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, + { static_cast(MetricType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, + { static_cast(MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP), "OPTIMIZER_EMPTY_RESULT_PULLUP" }, + { static_cast(MetricType::OPTIMIZER_CTE_FILTER_PUSHER), "OPTIMIZER_CTE_FILTER_PUSHER" }, + { static_cast(MetricType::OPTIMIZER_REGEX_RANGE), "OPTIMIZER_REGEX_RANGE" }, + { static_cast(MetricType::OPTIMIZER_IN_CLAUSE), "OPTIMIZER_IN_CLAUSE" }, + { static_cast(MetricType::OPTIMIZER_JOIN_ORDER), "OPTIMIZER_JOIN_ORDER" }, + { static_cast(MetricType::OPTIMIZER_DELIMINATOR), "OPTIMIZER_DELIMINATOR" }, + { static_cast(MetricType::OPTIMIZER_UNNEST_REWRITER), "OPTIMIZER_UNNEST_REWRITER" }, + { static_cast(MetricType::OPTIMIZER_UNUSED_COLUMNS), "OPTIMIZER_UNUSED_COLUMNS" }, + { static_cast(MetricType::OPTIMIZER_STATISTICS_PROPAGATION), "OPTIMIZER_STATISTICS_PROPAGATION" }, + { static_cast(MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS), "OPTIMIZER_COMMON_SUBEXPRESSIONS" }, + { static_cast(MetricType::OPTIMIZER_COMMON_AGGREGATE), "OPTIMIZER_COMMON_AGGREGATE" }, + { static_cast(MetricType::OPTIMIZER_COLUMN_LIFETIME), "OPTIMIZER_COLUMN_LIFETIME" }, + { static_cast(MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, + { static_cast(MetricType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, + { static_cast(MetricType::OPTIMIZER_ROW_GROUP_PRUNER), "OPTIMIZER_ROW_GROUP_PRUNER" }, + { static_cast(MetricType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, + { static_cast(MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION), "OPTIMIZER_TOP_N_WINDOW_ELIMINATION" }, + { 
static_cast(MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, + { static_cast(MetricType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, + { static_cast(MetricType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, + { static_cast(MetricType::OPTIMIZER_SAMPLING_PUSHDOWN), "OPTIMIZER_SAMPLING_PUSHDOWN" }, + { static_cast(MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN), "OPTIMIZER_JOIN_FILTER_PUSHDOWN" }, + { static_cast(MetricType::OPTIMIZER_EXTENSION), "OPTIMIZER_EXTENSION" }, + { static_cast(MetricType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, + { static_cast(MetricType::OPTIMIZER_SUM_REWRITER), "OPTIMIZER_SUM_REWRITER" }, + { static_cast(MetricType::OPTIMIZER_LATE_MATERIALIZATION), "OPTIMIZER_LATE_MATERIALIZATION" }, + { static_cast(MetricType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" }, + { static_cast(MetricType::OPTIMIZER_COMMON_SUBPLAN), "OPTIMIZER_COMMON_SUBPLAN" }, + { static_cast(MetricType::OPTIMIZER_JOIN_ELIMINATION), "OPTIMIZER_JOIN_ELIMINATION" }, + { static_cast(MetricType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, + { static_cast(MetricType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, + { static_cast(MetricType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, + { static_cast(MetricType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, + { static_cast(MetricType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, + { static_cast(MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, + { static_cast(MetricType::PLANNER), "PLANNER" }, + { static_cast(MetricType::PLANNER_BINDING), "PLANNER_BINDING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(MetricType value) { + return StringUtil::EnumToString(GetMetricTypeValues(), 66, "MetricType", static_cast(value)); +} + +template<> +MetricType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetMetricTypeValues(), 66, "MetricType", value)); } const StringUtil::EnumStringLiteral *GetMultiFileColumnMappingModeValues() { @@ -3059,7 +3210,9 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::COLUMN_LIFETIME), "COLUMN_LIFETIME" }, { static_cast(OptimizerType::BUILD_SIDE_PROBE_SIDE), "BUILD_SIDE_PROBE_SIDE" }, { static_cast(OptimizerType::LIMIT_PUSHDOWN), "LIMIT_PUSHDOWN" }, + { static_cast(OptimizerType::ROW_GROUP_PRUNER), "ROW_GROUP_PRUNER" }, { static_cast(OptimizerType::TOP_N), "TOP_N" }, + { static_cast(OptimizerType::TOP_N_WINDOW_ELIMINATION), "TOP_N_WINDOW_ELIMINATION" }, { static_cast(OptimizerType::COMPRESSED_MATERIALIZATION), "COMPRESSED_MATERIALIZATION" }, { static_cast(OptimizerType::DUPLICATE_GROUPS), "DUPLICATE_GROUPS" }, { static_cast(OptimizerType::REORDER_FILTER), "REORDER_FILTER" }, @@ -3069,19 +3222,21 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::MATERIALIZED_CTE), "MATERIALIZED_CTE" }, { static_cast(OptimizerType::SUM_REWRITER), "SUM_REWRITER" }, { static_cast(OptimizerType::LATE_MATERIALIZATION), "LATE_MATERIALIZATION" }, - { static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" } + { static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" }, + { static_cast(OptimizerType::COMMON_SUBPLAN), "COMMON_SUBPLAN" }, + { static_cast(OptimizerType::JOIN_ELIMINATION), "JOIN_ELIMINATION" } }; return values; } template<> const char* EnumUtil::ToChars(OptimizerType value) { - return 
StringUtil::EnumToString(GetOptimizerTypeValues(), 29, "OptimizerType", static_cast(value)); + return StringUtil::EnumToString(GetOptimizerTypeValues(), 33, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 29, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 33, "OptimizerType", value)); } const StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() { @@ -3237,28 +3392,6 @@ ParserExtensionResultType EnumUtil::FromString(const return static_cast(StringUtil::StringToEnum(GetParserExtensionResultTypeValues(), 3, "ParserExtensionResultType", value)); } -const StringUtil::EnumStringLiteral *GetPartitionSortStageValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PartitionSortStage::INIT), "INIT" }, - { static_cast(PartitionSortStage::SCAN), "SCAN" }, - { static_cast(PartitionSortStage::PREPARE), "PREPARE" }, - { static_cast(PartitionSortStage::MERGE), "MERGE" }, - { static_cast(PartitionSortStage::SORTED), "SORTED" }, - { static_cast(PartitionSortStage::FINISHED), "FINISHED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PartitionSortStage value) { - return StringUtil::EnumToString(GetPartitionSortStageValues(), 6, "PartitionSortStage", static_cast(value)); -} - -template<> -PartitionSortStage EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPartitionSortStageValues(), 6, "PartitionSortStage", value)); -} - const StringUtil::EnumStringLiteral *GetPartitionedColumnDataTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(PartitionedColumnDataType::INVALID), "INVALID" }, @@ -3416,6 +3549,26 @@ PhysicalOperatorType EnumUtil::FromString(const char *valu return static_cast(StringUtil::StringToEnum(GetPhysicalOperatorTypeValues(), 82, "PhysicalOperatorType", value)); } +const StringUtil::EnumStringLiteral *GetPhysicalTableScanExecutionStrategyValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(PhysicalTableScanExecutionStrategy::DEFAULT), "DEFAULT" }, + { static_cast(PhysicalTableScanExecutionStrategy::TASK_EXECUTOR), "TASK_EXECUTOR" }, + { static_cast(PhysicalTableScanExecutionStrategy::SYNCHRONOUS), "SYNCHRONOUS" }, + { static_cast(PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS), "TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(PhysicalTableScanExecutionStrategy value) { + return StringUtil::EnumToString(GetPhysicalTableScanExecutionStrategyValues(), 4, "PhysicalTableScanExecutionStrategy", static_cast(value)); +} + +template<> +PhysicalTableScanExecutionStrategy EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetPhysicalTableScanExecutionStrategyValues(), 4, "PhysicalTableScanExecutionStrategy", value)); +} + const StringUtil::EnumStringLiteral *GetPhysicalTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(PhysicalType::BOOL), "BOOL" }, @@ -3535,19 +3688,20 @@ const StringUtil::EnumStringLiteral *GetProfilerPrintFormatValues() { { static_cast(ProfilerPrintFormat::QUERY_TREE_OPTIMIZER), "QUERY_TREE_OPTIMIZER" }, { static_cast(ProfilerPrintFormat::NO_OUTPUT), "NO_OUTPUT" }, { static_cast(ProfilerPrintFormat::HTML), "HTML" }, - { static_cast(ProfilerPrintFormat::GRAPHVIZ), "GRAPHVIZ" } + { 
static_cast(ProfilerPrintFormat::GRAPHVIZ), "GRAPHVIZ" }, + { static_cast(ProfilerPrintFormat::MERMAID), "MERMAID" } }; return values; } template<> const char* EnumUtil::ToChars(ProfilerPrintFormat value) { - return StringUtil::EnumToString(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", static_cast(value)); + return StringUtil::EnumToString(GetProfilerPrintFormatValues(), 7, "ProfilerPrintFormat", static_cast(value)); } template<> ProfilerPrintFormat EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", value)); + return static_cast(StringUtil::StringToEnum(GetProfilerPrintFormatValues(), 7, "ProfilerPrintFormat", value)); } const StringUtil::EnumStringLiteral *GetProfilingCoverageValues() { @@ -3595,19 +3749,56 @@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { { static_cast(QueryNodeType::SET_OPERATION_NODE), "SET_OPERATION_NODE" }, { static_cast(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" }, { static_cast(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" }, - { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" } + { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" }, + { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" } }; return values; } template<> const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 5, "QueryNodeType", static_cast(value)); + return StringUtil::EnumToString(GetQueryNodeTypeValues(), 6, "QueryNodeType", static_cast(value)); } template<> QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 5, "QueryNodeType", value)); + return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 6, "QueryNodeType", value)); +} + +const StringUtil::EnumStringLiteral *GetQueryResultMemoryTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(QueryResultMemoryType::IN_MEMORY), "IN_MEMORY" }, + { static_cast(QueryResultMemoryType::BUFFER_MANAGED), "BUFFER_MANAGED" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(QueryResultMemoryType value) { + return StringUtil::EnumToString(GetQueryResultMemoryTypeValues(), 2, "QueryResultMemoryType", static_cast(value)); +} + +template<> +QueryResultMemoryType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetQueryResultMemoryTypeValues(), 2, "QueryResultMemoryType", value)); +} + +const StringUtil::EnumStringLiteral *GetQueryResultOutputTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(QueryResultOutputType::FORCE_MATERIALIZED), "FORCE_MATERIALIZED" }, + { static_cast(QueryResultOutputType::ALLOW_STREAMING), "ALLOW_STREAMING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(QueryResultOutputType value) { + return StringUtil::EnumToString(GetQueryResultOutputTypeValues(), 2, "QueryResultOutputType", static_cast(value)); +} + +template<> +QueryResultOutputType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetQueryResultOutputTypeValues(), 2, "QueryResultOutputType", value)); } const StringUtil::EnumStringLiteral *GetQueryResultTypeValues() { @@ -3630,6 +3821,24 @@ QueryResultType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetQueryResultTypeValues(), 4, "QueryResultType", value)); } +const StringUtil::EnumStringLiteral 
*GetRecoveryModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(RecoveryMode::DEFAULT), "DEFAULT" }, + { static_cast(RecoveryMode::NO_WAL_WRITES), "NO_WAL_WRITES" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(RecoveryMode value) { + return StringUtil::EnumToString(GetRecoveryModeValues(), 2, "RecoveryMode", static_cast(value)); +} + +template<> +RecoveryMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetRecoveryModeValues(), 2, "RecoveryMode", value)); +} + const StringUtil::EnumStringLiteral *GetRelationTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(RelationType::INVALID_RELATION), "INVALID_RELATION" }, @@ -3659,19 +3868,20 @@ const StringUtil::EnumStringLiteral *GetRelationTypeValues() { { static_cast(RelationType::VIEW_RELATION), "VIEW_RELATION" }, { static_cast(RelationType::QUERY_RELATION), "QUERY_RELATION" }, { static_cast(RelationType::DELIM_JOIN_RELATION), "DELIM_JOIN_RELATION" }, - { static_cast(RelationType::DELIM_GET_RELATION), "DELIM_GET_RELATION" } + { static_cast(RelationType::DELIM_GET_RELATION), "DELIM_GET_RELATION" }, + { static_cast(RelationType::EXTENSION_RELATION), "EXTENSION_RELATION" } }; return values; } template<> const char* EnumUtil::ToChars(RelationType value) { - return StringUtil::EnumToString(GetRelationTypeValues(), 28, "RelationType", static_cast(value)); + return StringUtil::EnumToString(GetRelationTypeValues(), 29, "RelationType", static_cast(value)); } template<> RelationType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetRelationTypeValues(), 28, "RelationType", value)); + return static_cast(StringUtil::StringToEnum(GetRelationTypeValues(), 29, "RelationType", value)); } const StringUtil::EnumStringLiteral *GetRenderModeValues() { @@ -3993,19 +4203,21 @@ const StringUtil::EnumStringLiteral *GetSimplifiedTokenTypeValues() { { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR), "SIMPLIFIED_TOKEN_OPERATOR" }, { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD), "SIMPLIFIED_TOKEN_KEYWORD" }, { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT), "SIMPLIFIED_TOKEN_COMMENT" }, - { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR), "SIMPLIFIED_TOKEN_ERROR" } + { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR), "SIMPLIFIED_TOKEN_ERROR" }, + { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_EMPHASIS), "SIMPLIFIED_TOKEN_ERROR_EMPHASIS" }, + { static_cast(SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_SUGGESTION), "SIMPLIFIED_TOKEN_ERROR_SUGGESTION" } }; return values; } template<> const char* EnumUtil::ToChars(SimplifiedTokenType value) { - return StringUtil::EnumToString(GetSimplifiedTokenTypeValues(), 7, "SimplifiedTokenType", static_cast(value)); + return StringUtil::EnumToString(GetSimplifiedTokenTypeValues(), 9, "SimplifiedTokenType", static_cast(value)); } template<> SimplifiedTokenType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSimplifiedTokenTypeValues(), 7, "SimplifiedTokenType", value)); + return static_cast(StringUtil::StringToEnum(GetSimplifiedTokenTypeValues(), 9, "SimplifiedTokenType", value)); } const StringUtil::EnumStringLiteral *GetSinkCombineResultTypeValues() { @@ -4220,19 +4432,21 @@ const StringUtil::EnumStringLiteral *GetStatisticsTypeValues() { { static_cast(StatisticsType::LIST_STATS), "LIST_STATS" }, { static_cast(StatisticsType::STRUCT_STATS), 
"STRUCT_STATS" }, { static_cast(StatisticsType::BASE_STATS), "BASE_STATS" }, - { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" } + { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" }, + { static_cast(StatisticsType::GEOMETRY_STATS), "GEOMETRY_STATS" }, + { static_cast(StatisticsType::VARIANT_STATS), "VARIANT_STATS" } }; return values; } template<> const char* EnumUtil::ToChars(StatisticsType value) { - return StringUtil::EnumToString(GetStatisticsTypeValues(), 6, "StatisticsType", static_cast(value)); + return StringUtil::EnumToString(GetStatisticsTypeValues(), 8, "StatisticsType", static_cast(value)); } template<> StatisticsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 6, "StatisticsType", value)); + return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 8, "StatisticsType", value)); } const StringUtil::EnumStringLiteral *GetStatsInfoValues() { @@ -4401,19 +4615,20 @@ const StringUtil::EnumStringLiteral *GetTableFilterTypeValues() { { static_cast(TableFilterType::OPTIONAL_FILTER), "OPTIONAL_FILTER" }, { static_cast(TableFilterType::IN_FILTER), "IN_FILTER" }, { static_cast(TableFilterType::DYNAMIC_FILTER), "DYNAMIC_FILTER" }, - { static_cast(TableFilterType::EXPRESSION_FILTER), "EXPRESSION_FILTER" } + { static_cast(TableFilterType::EXPRESSION_FILTER), "EXPRESSION_FILTER" }, + { static_cast(TableFilterType::BLOOM_FILTER), "BLOOM_FILTER" } }; return values; } template<> const char* EnumUtil::ToChars(TableFilterType value) { - return StringUtil::EnumToString(GetTableFilterTypeValues(), 10, "TableFilterType", static_cast(value)); + return StringUtil::EnumToString(GetTableFilterTypeValues(), 11, "TableFilterType", static_cast(value)); } template<> TableFilterType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 10, "TableFilterType", value)); + return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 11, "TableFilterType", value)); } const StringUtil::EnumStringLiteral *GetTablePartitionInfoValues() { @@ -4806,6 +5021,7 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { { static_cast(VariantLogicalType::ARRAY), "ARRAY" }, { static_cast(VariantLogicalType::BIGNUM), "BIGNUM" }, { static_cast(VariantLogicalType::BITSTRING), "BITSTRING" }, + { static_cast(VariantLogicalType::GEOMETRY), "GEOMETRY" }, { static_cast(VariantLogicalType::ENUM_SIZE), "ENUM_SIZE" } }; return values; @@ -4813,12 +5029,52 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { template<> const char* EnumUtil::ToChars(VariantLogicalType value) { - return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", static_cast(value)); + return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", static_cast(value)); } template<> VariantLogicalType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", value)); + return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", value)); +} + +const StringUtil::EnumStringLiteral *GetVariantStatsShreddingStateValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(VariantStatsShreddingState::UNINITIALIZED), "UNINITIALIZED" }, + { static_cast(VariantStatsShreddingState::NOT_SHREDDED), "NOT_SHREDDED" }, + { 
static_cast(VariantStatsShreddingState::SHREDDED), "SHREDDED" }, + { static_cast(VariantStatsShreddingState::INCONSISTENT), "INCONSISTENT" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(VariantStatsShreddingState value) { + return StringUtil::EnumToString(GetVariantStatsShreddingStateValues(), 4, "VariantStatsShreddingState", static_cast(value)); +} + +template<> +VariantStatsShreddingState EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetVariantStatsShreddingStateValues(), 4, "VariantStatsShreddingState", value)); +} + +const StringUtil::EnumStringLiteral *GetVariantValueTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(VariantValueType::PRIMITIVE), "PRIMITIVE" }, + { static_cast(VariantValueType::OBJECT), "OBJECT" }, + { static_cast(VariantValueType::ARRAY), "ARRAY" }, + { static_cast(VariantValueType::MISSING), "MISSING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(VariantValueType value) { + return StringUtil::EnumToString(GetVariantValueTypeValues(), 4, "VariantValueType", static_cast(value)); +} + +template<> +VariantValueType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetVariantValueTypeValues(), 4, "VariantValueType", value)); } const StringUtil::EnumStringLiteral *GetVectorAuxiliaryDataTypeValues() { @@ -4931,6 +5187,26 @@ VerifyExistenceType EnumUtil::FromString(const char *value) return static_cast(StringUtil::StringToEnum(GetVerifyExistenceTypeValues(), 3, "VerifyExistenceType", value)); } +const StringUtil::EnumStringLiteral *GetVertexTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(VertexType::XY), "XY" }, + { static_cast(VertexType::XYZ), "XYZ" }, + { static_cast(VertexType::XYM), "XYM" }, + { static_cast(VertexType::XYZM), "XYZM" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(VertexType value) { + return StringUtil::EnumToString(GetVertexTypeValues(), 4, "VertexType", static_cast(value)); +} + +template<> +VertexType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetVertexTypeValues(), 4, "VertexType", value)); +} + const StringUtil::EnumStringLiteral *GetWALTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(WALType::INVALID), "INVALID" }, diff --git a/src/duckdb/src/common/enums/compression_type.cpp b/src/duckdb/src/common/enums/compression_type.cpp index ec551eff1..427cfbe91 100644 --- a/src/duckdb/src/common/enums/compression_type.cpp +++ b/src/duckdb/src/common/enums/compression_type.cpp @@ -17,25 +17,60 @@ vector ListCompressionTypes(void) { return compression_types; } -bool CompressionTypeIsDeprecated(CompressionType compression_type, optional_ptr storage_manager) { - vector types({CompressionType::COMPRESSION_PATAS, CompressionType::COMPRESSION_CHIMP}); - if (storage_manager) { - if (storage_manager->GetStorageVersion() >= 5) { - //! NOTE: storage_manager is an optional_ptr because it's called from ForceCompressionSetting, which doesn't - //! have guaranteed access to a StorageManager The introduction of DICT_FSST deprecates Dictionary and FSST - //! 
compression methods - types.emplace_back(CompressionType::COMPRESSION_DICTIONARY); - types.emplace_back(CompressionType::COMPRESSION_FSST); - } else { - types.emplace_back(CompressionType::COMPRESSION_DICT_FSST); - } +namespace { +struct CompressionMethodRequirements { + CompressionType type; + optional_idx minimum_storage_version; + optional_idx maximum_storage_version; +}; +} // namespace + +CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compression_type, + optional_ptr storage_manager) { + //! Max storage compatibility + vector candidates({{CompressionType::COMPRESSION_PATAS, optional_idx(), 0}, + {CompressionType::COMPRESSION_CHIMP, optional_idx(), 0}, + {CompressionType::COMPRESSION_DICTIONARY, 0, 4}, + {CompressionType::COMPRESSION_FSST, 0, 4}, + {CompressionType::COMPRESSION_DICT_FSST, 5, optional_idx()}}); + + optional_idx current_storage_version; + if (storage_manager && storage_manager->HasStorageVersion()) { + current_storage_version = storage_manager->GetStorageVersion(); } - for (auto &type : types) { - if (type == compression_type) { - return true; + for (auto &candidate : candidates) { + auto &type = candidate.type; + if (type != compression_type) { + continue; + } + auto &min = candidate.minimum_storage_version; + auto &max = candidate.maximum_storage_version; + + if (!min.IsValid()) { + //! Used to signal: always deprecated + return CompressionAvailabilityResult::Deprecated(); + } + + if (!current_storage_version.IsValid()) { + //! Can't determine in this call whether it's available or not, default to available + return CompressionAvailabilityResult(); + } + + auto current_version = current_storage_version.GetIndex(); + D_ASSERT(min.IsValid()); + if (min.GetIndex() > current_version) { + //! Minimum required storage version is higher than the current storage version, this method isn't available + //! yet + return CompressionAvailabilityResult::NotAvailableYet(); + } + if (max.IsValid() && max.GetIndex() < current_version) { + //! Maximum supported storage version is lower than the current storage version, this method is no longer + //! 
available + return CompressionAvailabilityResult::Deprecated(); } + return CompressionAvailabilityResult(); } - return false; + return CompressionAvailabilityResult(); } CompressionType CompressionTypeFromString(const string &str) { diff --git a/src/duckdb/src/common/enums/expression_type.cpp b/src/duckdb/src/common/enums/expression_type.cpp index 6dbdf2612..f1375d3a1 100644 --- a/src/duckdb/src/common/enums/expression_type.cpp +++ b/src/duckdb/src/common/enums/expression_type.cpp @@ -287,6 +287,12 @@ ExpressionType NegateComparisonExpression(ExpressionType type) { case ExpressionType::COMPARE_GREATERTHANOREQUALTO: negated_type = ExpressionType::COMPARE_LESSTHAN; break; + case ExpressionType::COMPARE_DISTINCT_FROM: + negated_type = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + break; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + negated_type = ExpressionType::COMPARE_DISTINCT_FROM; + break; default: throw InternalException("Unsupported comparison type in negation"); } diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp index 866049251..f7788ef9c 100644 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ b/src/duckdb/src/common/enums/metric_type.cpp @@ -1,249 +1,245 @@ -//------------------------------------------------------------------------- -// DuckDB -// -// -// duckdb/common/enums/metrics_type.hpp -// // This file is automatically generated by scripts/generate_metric_enums.py // Do not edit this file manually, your changes will be overwritten -//------------------------------------------------------------------------- #include "duckdb/common/enums/metric_type.hpp" +#include "duckdb/common/enum_util.hpp" + namespace duckdb { +profiler_settings_t MetricsUtils::GetAllMetrics() { + profiler_settings_t result; + for (auto metric = START_CORE; metric <= END_PHASE_TIMING; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +profiler_settings_t MetricsUtils::GetMetricsByGroupType(MetricGroup type) { + switch(type) { + case MetricGroup::ALL: + return GetAllMetrics(); + case MetricGroup::CORE: + return GetCoreMetrics(); + case MetricGroup::DEFAULT: + return GetDefaultMetrics(); + case MetricGroup::EXECUTION: + return GetExecutionMetrics(); + case MetricGroup::FILE: + return GetFileMetrics(); + case MetricGroup::OPERATOR: + return GetOperatorMetrics(); + case MetricGroup::OPTIMIZER: + return GetOptimizerMetrics(); + case MetricGroup::PHASE_TIMING: + return GetPhaseTimingMetrics(); + default: + throw InternalException("The MetricGroup passed is invalid"); + } +} + +profiler_settings_t MetricsUtils::GetCoreMetrics() { + profiler_settings_t result; + for (auto metric = START_CORE; metric <= END_CORE; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsCoreMetric(MetricType type) { + return static_cast(type) <= END_CORE; +} + +profiler_settings_t MetricsUtils::GetDefaultMetrics() { + return { + MetricType::ATTACH_LOAD_STORAGE_LATENCY, + MetricType::ATTACH_REPLAY_WAL_LATENCY, + MetricType::BLOCKED_THREAD_TIME, + MetricType::CHECKPOINT_LATENCY, + MetricType::COMMIT_LOCAL_STORAGE_LATENCY, + MetricType::CPU_TIME, + MetricType::CUMULATIVE_CARDINALITY, + MetricType::CUMULATIVE_ROWS_SCANNED, + MetricType::EXTRA_INFO, + MetricType::LATENCY, + MetricType::OPERATOR_CARDINALITY, + MetricType::OPERATOR_NAME, + MetricType::OPERATOR_ROWS_SCANNED, + MetricType::OPERATOR_TIMING, + MetricType::OPERATOR_TYPE, + MetricType::QUERY_NAME, + MetricType::RESULT_SET_SIZE, + 
MetricType::ROWS_RETURNED, + MetricType::SYSTEM_PEAK_BUFFER_MEMORY, + MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, + MetricType::TOTAL_BYTES_READ, + MetricType::TOTAL_BYTES_WRITTEN, + MetricType::TOTAL_MEMORY_ALLOCATED, + MetricType::WAITING_TO_ATTACH_LATENCY, + MetricType::WAL_REPLAY_ENTRY_COUNT, + MetricType::WRITE_TO_WAL_LATENCY, + }; +} + +bool MetricsUtils::IsDefaultMetric(MetricType type) { + switch(type) { + case MetricType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricType::ATTACH_REPLAY_WAL_LATENCY: + case MetricType::BLOCKED_THREAD_TIME: + case MetricType::CHECKPOINT_LATENCY: + case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: + case MetricType::CPU_TIME: + case MetricType::CUMULATIVE_CARDINALITY: + case MetricType::CUMULATIVE_ROWS_SCANNED: + case MetricType::EXTRA_INFO: + case MetricType::LATENCY: + case MetricType::OPERATOR_CARDINALITY: + case MetricType::OPERATOR_NAME: + case MetricType::OPERATOR_ROWS_SCANNED: + case MetricType::OPERATOR_TIMING: + case MetricType::OPERATOR_TYPE: + case MetricType::QUERY_NAME: + case MetricType::RESULT_SET_SIZE: + case MetricType::ROWS_RETURNED: + case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: + case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricType::TOTAL_BYTES_READ: + case MetricType::TOTAL_BYTES_WRITTEN: + case MetricType::TOTAL_MEMORY_ALLOCATED: + case MetricType::WAITING_TO_ATTACH_LATENCY: + case MetricType::WAL_REPLAY_ENTRY_COUNT: + case MetricType::WRITE_TO_WAL_LATENCY: + return true; + default: + return false; + } +} + +profiler_settings_t MetricsUtils::GetExecutionMetrics() { + profiler_settings_t result; + for (auto metric = START_EXECUTION; metric <= END_EXECUTION; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsExecutionMetric(MetricType type) { + return static_cast(type) >= START_EXECUTION && static_cast(type) <= END_EXECUTION; +} + +profiler_settings_t MetricsUtils::GetFileMetrics() { + profiler_settings_t result; + for (auto metric = START_FILE; metric <= END_FILE; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsFileMetric(MetricType type) { + return static_cast(type) >= START_FILE && static_cast(type) <= END_FILE; +} + +profiler_settings_t MetricsUtils::GetOperatorMetrics() { + profiler_settings_t result; + for (auto metric = START_OPERATOR; metric <= END_OPERATOR; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsOperatorMetric(MetricType type) { + return static_cast(type) >= START_OPERATOR && static_cast(type) <= END_OPERATOR; +} + profiler_settings_t MetricsUtils::GetOptimizerMetrics() { - return { - MetricsType::OPTIMIZER_EXPRESSION_REWRITER, - MetricsType::OPTIMIZER_FILTER_PULLUP, - MetricsType::OPTIMIZER_FILTER_PUSHDOWN, - MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP, - MetricsType::OPTIMIZER_CTE_FILTER_PUSHER, - MetricsType::OPTIMIZER_REGEX_RANGE, - MetricsType::OPTIMIZER_IN_CLAUSE, - MetricsType::OPTIMIZER_JOIN_ORDER, - MetricsType::OPTIMIZER_DELIMINATOR, - MetricsType::OPTIMIZER_UNNEST_REWRITER, - MetricsType::OPTIMIZER_UNUSED_COLUMNS, - MetricsType::OPTIMIZER_STATISTICS_PROPAGATION, - MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS, - MetricsType::OPTIMIZER_COMMON_AGGREGATE, - MetricsType::OPTIMIZER_COLUMN_LIFETIME, - MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, - MetricsType::OPTIMIZER_LIMIT_PUSHDOWN, - MetricsType::OPTIMIZER_TOP_N, - MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION, - MetricsType::OPTIMIZER_DUPLICATE_GROUPS, - MetricsType::OPTIMIZER_REORDER_FILTER, - 
MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN, - MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN, - MetricsType::OPTIMIZER_EXTENSION, - MetricsType::OPTIMIZER_MATERIALIZED_CTE, - MetricsType::OPTIMIZER_SUM_REWRITER, - MetricsType::OPTIMIZER_LATE_MATERIALIZATION, - MetricsType::OPTIMIZER_CTE_INLINING, - }; + profiler_settings_t result; + for (auto metric = START_OPTIMIZER; metric <= END_OPTIMIZER; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsOptimizerMetric(MetricType type) { + return static_cast(type) >= START_OPTIMIZER && static_cast(type) <= END_OPTIMIZER; +} + + +MetricType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { + if (type == OptimizerType::INVALID) { + throw InternalException("Invalid OptimizerType: INVALID"); + } + + const auto base_opt = static_cast(OptimizerType::EXPRESSION_REWRITER); + const auto idx = static_cast(type) - base_opt; + + const auto metric_u8 = static_cast(START_OPTIMIZER + idx); + if (metric_u8 < START_OPTIMIZER || metric_u8 > END_OPTIMIZER) { + throw InternalException("OptimizerType out of MetricType optimizer range"); + } + return static_cast(metric_u8); +} + +OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricType type) { + const auto metric_u8 = static_cast(type); + if (!IsOptimizerMetric(type)) { + throw InternalException("MetricType is not an optimizer metric"); + } + + const auto idx = static_cast(metric_u8 - START_OPTIMIZER); + const auto result = static_cast(OptimizerType::EXPRESSION_REWRITER) + idx; + return static_cast(result); } profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { - return { - MetricsType::ALL_OPTIMIZERS, - MetricsType::CUMULATIVE_OPTIMIZER_TIMING, - MetricsType::PLANNER, - MetricsType::PLANNER_BINDING, - MetricsType::PHYSICAL_PLANNER, - MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, - MetricsType::PHYSICAL_PLANNER_CREATE_PLAN, - }; -} - -MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { - switch(type) { - case OptimizerType::EXPRESSION_REWRITER: - return MetricsType::OPTIMIZER_EXPRESSION_REWRITER; - case OptimizerType::FILTER_PULLUP: - return MetricsType::OPTIMIZER_FILTER_PULLUP; - case OptimizerType::FILTER_PUSHDOWN: - return MetricsType::OPTIMIZER_FILTER_PUSHDOWN; - case OptimizerType::EMPTY_RESULT_PULLUP: - return MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP; - case OptimizerType::CTE_FILTER_PUSHER: - return MetricsType::OPTIMIZER_CTE_FILTER_PUSHER; - case OptimizerType::REGEX_RANGE: - return MetricsType::OPTIMIZER_REGEX_RANGE; - case OptimizerType::IN_CLAUSE: - return MetricsType::OPTIMIZER_IN_CLAUSE; - case OptimizerType::JOIN_ORDER: - return MetricsType::OPTIMIZER_JOIN_ORDER; - case OptimizerType::DELIMINATOR: - return MetricsType::OPTIMIZER_DELIMINATOR; - case OptimizerType::UNNEST_REWRITER: - return MetricsType::OPTIMIZER_UNNEST_REWRITER; - case OptimizerType::UNUSED_COLUMNS: - return MetricsType::OPTIMIZER_UNUSED_COLUMNS; - case OptimizerType::STATISTICS_PROPAGATION: - return MetricsType::OPTIMIZER_STATISTICS_PROPAGATION; - case OptimizerType::COMMON_SUBEXPRESSIONS: - return MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS; - case OptimizerType::COMMON_AGGREGATE: - return MetricsType::OPTIMIZER_COMMON_AGGREGATE; - case OptimizerType::COLUMN_LIFETIME: - return MetricsType::OPTIMIZER_COLUMN_LIFETIME; - case OptimizerType::BUILD_SIDE_PROBE_SIDE: - return MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE; - case OptimizerType::LIMIT_PUSHDOWN: - return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; - case 
OptimizerType::TOP_N: - return MetricsType::OPTIMIZER_TOP_N; - case OptimizerType::COMPRESSED_MATERIALIZATION: - return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; - case OptimizerType::DUPLICATE_GROUPS: - return MetricsType::OPTIMIZER_DUPLICATE_GROUPS; - case OptimizerType::REORDER_FILTER: - return MetricsType::OPTIMIZER_REORDER_FILTER; - case OptimizerType::SAMPLING_PUSHDOWN: - return MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN; - case OptimizerType::JOIN_FILTER_PUSHDOWN: - return MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN; - case OptimizerType::EXTENSION: - return MetricsType::OPTIMIZER_EXTENSION; - case OptimizerType::MATERIALIZED_CTE: - return MetricsType::OPTIMIZER_MATERIALIZED_CTE; - case OptimizerType::SUM_REWRITER: - return MetricsType::OPTIMIZER_SUM_REWRITER; - case OptimizerType::LATE_MATERIALIZATION: - return MetricsType::OPTIMIZER_LATE_MATERIALIZATION; - case OptimizerType::CTE_INLINING: - return MetricsType::OPTIMIZER_CTE_INLINING; - default: - throw InternalException("OptimizerType %s cannot be converted to a MetricsType", EnumUtil::ToString(type)); - }; -} - -OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { - switch(type) { - case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: - return OptimizerType::EXPRESSION_REWRITER; - case MetricsType::OPTIMIZER_FILTER_PULLUP: - return OptimizerType::FILTER_PULLUP; - case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: - return OptimizerType::FILTER_PUSHDOWN; - case MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP: - return OptimizerType::EMPTY_RESULT_PULLUP; - case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: - return OptimizerType::CTE_FILTER_PUSHER; - case MetricsType::OPTIMIZER_REGEX_RANGE: - return OptimizerType::REGEX_RANGE; - case MetricsType::OPTIMIZER_IN_CLAUSE: - return OptimizerType::IN_CLAUSE; - case MetricsType::OPTIMIZER_JOIN_ORDER: - return OptimizerType::JOIN_ORDER; - case MetricsType::OPTIMIZER_DELIMINATOR: - return OptimizerType::DELIMINATOR; - case MetricsType::OPTIMIZER_UNNEST_REWRITER: - return OptimizerType::UNNEST_REWRITER; - case MetricsType::OPTIMIZER_UNUSED_COLUMNS: - return OptimizerType::UNUSED_COLUMNS; - case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: - return OptimizerType::STATISTICS_PROPAGATION; - case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - return OptimizerType::COMMON_SUBEXPRESSIONS; - case MetricsType::OPTIMIZER_COMMON_AGGREGATE: - return OptimizerType::COMMON_AGGREGATE; - case MetricsType::OPTIMIZER_COLUMN_LIFETIME: - return OptimizerType::COLUMN_LIFETIME; - case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - return OptimizerType::BUILD_SIDE_PROBE_SIDE; - case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: - return OptimizerType::LIMIT_PUSHDOWN; - case MetricsType::OPTIMIZER_TOP_N: - return OptimizerType::TOP_N; - case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - return OptimizerType::COMPRESSED_MATERIALIZATION; - case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: - return OptimizerType::DUPLICATE_GROUPS; - case MetricsType::OPTIMIZER_REORDER_FILTER: - return OptimizerType::REORDER_FILTER; - case MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN: - return OptimizerType::SAMPLING_PUSHDOWN; - case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - return OptimizerType::JOIN_FILTER_PUSHDOWN; - case MetricsType::OPTIMIZER_EXTENSION: - return OptimizerType::EXTENSION; - case MetricsType::OPTIMIZER_MATERIALIZED_CTE: - return OptimizerType::MATERIALIZED_CTE; - case MetricsType::OPTIMIZER_SUM_REWRITER: - return OptimizerType::SUM_REWRITER; - case MetricsType::OPTIMIZER_LATE_MATERIALIZATION: - return 
OptimizerType::LATE_MATERIALIZATION; - case MetricsType::OPTIMIZER_CTE_INLINING: - return OptimizerType::CTE_INLINING; - default: - return OptimizerType::INVALID; - }; -} - -bool MetricsUtils::IsOptimizerMetric(MetricsType type) { - switch(type) { - case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: - case MetricsType::OPTIMIZER_FILTER_PULLUP: - case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: - case MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP: - case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: - case MetricsType::OPTIMIZER_REGEX_RANGE: - case MetricsType::OPTIMIZER_IN_CLAUSE: - case MetricsType::OPTIMIZER_JOIN_ORDER: - case MetricsType::OPTIMIZER_DELIMINATOR: - case MetricsType::OPTIMIZER_UNNEST_REWRITER: - case MetricsType::OPTIMIZER_UNUSED_COLUMNS: - case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: - case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - case MetricsType::OPTIMIZER_COMMON_AGGREGATE: - case MetricsType::OPTIMIZER_COLUMN_LIFETIME: - case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: - case MetricsType::OPTIMIZER_TOP_N: - case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: - case MetricsType::OPTIMIZER_REORDER_FILTER: - case MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN: - case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - case MetricsType::OPTIMIZER_EXTENSION: - case MetricsType::OPTIMIZER_MATERIALIZED_CTE: - case MetricsType::OPTIMIZER_SUM_REWRITER: - case MetricsType::OPTIMIZER_LATE_MATERIALIZATION: - case MetricsType::OPTIMIZER_CTE_INLINING: - return true; - default: - return false; - }; -} - -bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { - switch(type) { - case MetricsType::ALL_OPTIMIZERS: - case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricsType::PLANNER: - case MetricsType::PLANNER_BINDING: - case MetricsType::PHYSICAL_PLANNER: - case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: - return true; - default: - return false; - }; -} - -bool MetricsUtils::IsQueryGlobalMetric(MetricsType type) { - switch(type) { - case MetricsType::BLOCKED_THREAD_TIME: - case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: - return true; - default: - return false; - }; -} - -} // namespace duckdb + profiler_settings_t result; + for (auto metric = START_PHASE_TIMING; metric <= END_PHASE_TIMING; metric++) { + result.insert(static_cast(metric)); + } + return result; +} + +bool MetricsUtils::IsPhaseTimingMetric(MetricType type) { + return static_cast(type) >= START_PHASE_TIMING && static_cast(type) <= END_PHASE_TIMING; +} + +profiler_settings_t MetricsUtils::GetRootScopeMetrics() { + return { + MetricType::ATTACH_LOAD_STORAGE_LATENCY, + MetricType::ATTACH_REPLAY_WAL_LATENCY, + MetricType::BLOCKED_THREAD_TIME, + MetricType::CHECKPOINT_LATENCY, + MetricType::COMMIT_LOCAL_STORAGE_LATENCY, + MetricType::LATENCY, + MetricType::QUERY_NAME, + MetricType::ROWS_RETURNED, + MetricType::TOTAL_BYTES_READ, + MetricType::TOTAL_BYTES_WRITTEN, + MetricType::TOTAL_MEMORY_ALLOCATED, + MetricType::WAITING_TO_ATTACH_LATENCY, + MetricType::WAL_REPLAY_ENTRY_COUNT, + MetricType::WRITE_TO_WAL_LATENCY, + }; +} + +bool MetricsUtils::IsRootScopeMetric(MetricType type) { + switch(type) { + case MetricType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricType::ATTACH_REPLAY_WAL_LATENCY: + case MetricType::BLOCKED_THREAD_TIME: + case MetricType::CHECKPOINT_LATENCY: + case 
MetricType::COMMIT_LOCAL_STORAGE_LATENCY: + case MetricType::LATENCY: + case MetricType::QUERY_NAME: + case MetricType::ROWS_RETURNED: + case MetricType::TOTAL_BYTES_READ: + case MetricType::TOTAL_BYTES_WRITTEN: + case MetricType::TOTAL_MEMORY_ALLOCATED: + case MetricType::WAITING_TO_ATTACH_LATENCY: + case MetricType::WAL_REPLAY_ENTRY_COUNT: + case MetricType::WRITE_TO_WAL_LATENCY: + return true; + default: + return false; + } +} + +} diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index b0d669500..f62af9626 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/optimizer/optimizer.hpp" namespace duckdb { @@ -28,7 +29,9 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"common_aggregate", OptimizerType::COMMON_AGGREGATE}, {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, {"limit_pushdown", OptimizerType::LIMIT_PUSHDOWN}, + {"row_group_pruner", OptimizerType::ROW_GROUP_PRUNER}, {"top_n", OptimizerType::TOP_N}, + {"top_n_window_elimination", OptimizerType::TOP_N_WINDOW_ELIMINATION}, {"build_side_probe_side", OptimizerType::BUILD_SIDE_PROBE_SIDE}, {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, @@ -40,6 +43,8 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"sum_rewriter", OptimizerType::SUM_REWRITER}, {"late_materialization", OptimizerType::LATE_MATERIALIZATION}, {"cte_inlining", OptimizerType::CTE_INLINING}, + {"common_subplan", OptimizerType::COMMON_SUBPLAN}, + {"join_elimination", OptimizerType::JOIN_ELIMINATION}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/enums/relation_type.cpp b/src/duckdb/src/common/enums/relation_type.cpp index 4f58ed7c4..dc02b8970 100644 --- a/src/duckdb/src/common/enums/relation_type.cpp +++ b/src/duckdb/src/common/enums/relation_type.cpp @@ -61,6 +61,8 @@ string RelationTypeToString(RelationType type) { return "VIEW_RELATION"; case RelationType::QUERY_RELATION: return "QUERY_RELATION"; + case RelationType::EXTENSION_RELATION: + return "EXTENSION_RELATION"; case RelationType::INVALID_RELATION: break; } diff --git a/src/duckdb/src/common/enums/statement_type.cpp b/src/duckdb/src/common/enums/statement_type.cpp index 20251a934..643df0076 100644 --- a/src/duckdb/src/common/enums/statement_type.cpp +++ b/src/duckdb/src/common/enums/statement_type.cpp @@ -92,11 +92,18 @@ void StatementProperties::RegisterDBRead(Catalog &catalog, ClientContext &contex read_databases[catalog.GetName()] = catalog_identity; } -void StatementProperties::RegisterDBModify(Catalog &catalog, ClientContext &context) { +void StatementProperties::RegisterDBModify(Catalog &catalog, ClientContext &context, + DatabaseModificationType modification) { auto catalog_identity = CatalogIdentity {catalog.GetOid(), catalog.GetCatalogVersion(context)}; - D_ASSERT(modified_databases.count(catalog.GetName()) == 0 || - modified_databases[catalog.GetName()] == catalog_identity); - modified_databases[catalog.GetName()] = catalog_identity; + auto entry = modified_databases.insert(make_pair(catalog.GetName(), ModificationInfo())); + if (entry.second) { + // new entry - set the identity + entry.first->second.identity = 
catalog_identity; + } else { + // existing entry - verify this has the same identity + D_ASSERT(entry.first->second.identity == catalog_identity); + } + entry.first->second.modifications |= modification; } } // namespace duckdb diff --git a/src/duckdb/src/common/error_data.cpp b/src/duckdb/src/common/error_data.cpp index 2ddf94af6..44f70a085 100644 --- a/src/duckdb/src/common/error_data.cpp +++ b/src/duckdb/src/common/error_data.cpp @@ -24,7 +24,6 @@ ErrorData::ErrorData(ExceptionType type, const string &message) ErrorData::ErrorData(const string &message) : initialized(true), type(ExceptionType::INVALID), raw_message(string()), final_message(string()) { - // parse the constructed JSON if (message.empty() || message[0] != '{') { // not JSON! Use the message as a raw Exception message and leave type as uninitialized @@ -61,12 +60,12 @@ string ErrorData::ConstructFinalMessage() const { error = Exception::ExceptionTypeToString(type) + " "; } error += "Error: " + raw_message; - if (type == ExceptionType::INTERNAL) { + if (type == ExceptionType::INTERNAL || type == ExceptionType::FATAL) { error += "\nThis error signals an assertion failure within DuckDB. This usually occurs due to " "unexpected conditions or errors in the program's logic.\nFor more information, see " "https://duckdb.org/docs/stable/dev/internal_errors"; - // Ensure that we print the stack trace for internal exceptions. + // Ensure that we print the stack trace for internal and fatal exceptions. auto entry = extra_info.find("stack_trace_pointers"); if (entry != extra_info.end()) { auto stack_trace = StackTrace::ResolveStacktraceSymbols(entry->second); @@ -80,9 +79,9 @@ void ErrorData::Throw(const string &prepended_message) const { D_ASSERT(initialized); if (!prepended_message.empty()) { string new_message = prepended_message + raw_message; - throw Exception(type, new_message, extra_info); + throw Exception(extra_info, type, new_message); } else { - throw Exception(type, raw_message, extra_info); + throw Exception(extra_info, type, raw_message); } } diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp index 2012c1fcc..ce22091c8 100644 --- a/src/duckdb/src/common/exception.cpp +++ b/src/duckdb/src/common/exception.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/types.hpp" #include "duckdb/common/exception/list.hpp" #include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/planner/expression.hpp" #ifdef DUCKDB_CRASH_ON_ASSERT @@ -19,17 +20,17 @@ Exception::Exception(ExceptionType exception_type, const string &message) : std::runtime_error(ToJSON(exception_type, message)) { } -Exception::Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info) - : std::runtime_error(ToJSON(exception_type, message, extra_info)) { +Exception::Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message) + : std::runtime_error(ToJSON(extra_info, exception_type, message)) { } string Exception::ToJSON(ExceptionType type, const string &message) { unordered_map extra_info; - return ToJSON(type, message, extra_info); + return ToJSON(extra_info, type, message); } -string Exception::ToJSON(ExceptionType type, const string &message, const unordered_map &extra_info) { +string Exception::ToJSON(const unordered_map &extra_info, ExceptionType type, const string &message) { #ifndef DUCKDB_DEBUG_STACKTRACE // by default we only enable stack traces for internal exceptions if (type == ExceptionType::INTERNAL || type == 
ExceptionType::FATAL) @@ -240,9 +241,8 @@ TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const Lo TypeMismatchException::TypeMismatchException(optional_idx error_location, const LogicalType &type_1, const LogicalType &type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg, - Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::MISMATCH_TYPE, + "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". " + msg) { } TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { @@ -306,8 +306,12 @@ DependencyException::DependencyException(const string &msg) : Exception(Exceptio IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { } -IOException::IOException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::IO, msg, extra_info) { +IOException::IOException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::IO, msg) { +} + +NotImplementedException::NotImplementedException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::NOT_IMPLEMENTED, msg) { } MissingExtensionException::MissingExtensionException(const string &msg) @@ -330,29 +334,35 @@ InterruptException::InterruptException() : Exception(ExceptionType::INTERRUPT, " } FatalException::FatalException(ExceptionType type, const string &msg) : Exception(type, msg) { + // FIXME: Make any log context available to add error logging. } InternalException::InternalException(const string &msg) : Exception(ExceptionType::INTERNAL, msg) { + // FIXME: Make any log context available to add error logging. 
#ifdef DUCKDB_CRASH_ON_ASSERT Printer::Print("ABORT THROWN BY INTERNAL EXCEPTION: " + msg + "\n" + StackTrace::GetStackTrace()); abort(); #endif } +InternalException::InternalException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::INTERNAL, msg) { +} + InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { } -InvalidInputException::InvalidInputException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_INPUT, msg, extra_info) { +InvalidInputException::InvalidInputException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::INVALID_INPUT, msg) { } InvalidConfigurationException::InvalidConfigurationException(const string &msg) : Exception(ExceptionType::INVALID_CONFIGURATION, msg) { } -InvalidConfigurationException::InvalidConfigurationException(const string &msg, - const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_CONFIGURATION, msg, extra_info) { +InvalidConfigurationException::InvalidConfigurationException(const unordered_map &extra_info, + const string &msg) + : Exception(extra_info, ExceptionType::INVALID_CONFIGURATION, msg) { } OutOfMemoryException::OutOfMemoryException(const string &msg) diff --git a/src/duckdb/src/common/exception/binder_exception.cpp b/src/duckdb/src/common/exception/binder_exception.cpp index 62dca06fb..aa9a9459e 100644 --- a/src/duckdb/src/common/exception/binder_exception.cpp +++ b/src/duckdb/src/common/exception/binder_exception.cpp @@ -7,8 +7,8 @@ namespace duckdb { BinderException::BinderException(const string &msg) : Exception(ExceptionType::BINDER, msg) { } -BinderException::BinderException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::BINDER, msg, extra_info) { +BinderException::BinderException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::BINDER, msg) { } BinderException BinderException::ColumnNotFound(const string &name, const vector &similar_bindings, @@ -18,9 +18,13 @@ BinderException BinderException::ColumnNotFound(const string &name, const vector extra_info["name"] = name; if (!similar_bindings.empty()) { extra_info["candidates"] = StringUtil::Join(similar_bindings, ","); + return BinderException(extra_info, StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", + name, candidate_str)); + } else { + return BinderException( + extra_info, + StringUtil::Format("Referenced column \"%s\" was not found because the FROM clause is missing", name)); } - return BinderException( - StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", name, candidate_str), extra_info); } BinderException BinderException::NoMatchingFunction(const string &catalog_name, const string &schema_name, @@ -45,15 +49,14 @@ BinderException BinderException::NoMatchingFunction(const string &catalog_name, extra_info["candidates"] = StringUtil::Join(candidates, ","); } return BinderException( + extra_info, StringUtil::Format("No function matches the given name and argument types '%s'. 
You might need to add " "explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str), - extra_info); + call_str, candidate_str)); } BinderException BinderException::Unsupported(ParsedExpression &expr, const string &message) { auto extra_info = Exception::InitializeExtraInfo("UNSUPPORTED", expr.GetQueryLocation()); - return BinderException(message, extra_info); + return BinderException(extra_info, message); } - } // namespace duckdb diff --git a/src/duckdb/src/common/exception/catalog_exception.cpp b/src/duckdb/src/common/exception/catalog_exception.cpp index 5d890f1cd..b1cd4caf7 100644 --- a/src/duckdb/src/common/exception/catalog_exception.cpp +++ b/src/duckdb/src/common/exception/catalog_exception.cpp @@ -9,8 +9,8 @@ namespace duckdb { CatalogException::CatalogException(const string &msg) : Exception(ExceptionType::CATALOG, msg) { } -CatalogException::CatalogException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::CATALOG, msg, extra_info) { +CatalogException::CatalogException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::CATALOG, msg) { } CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion) { @@ -35,9 +35,9 @@ CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_in if (!suggestion.empty()) { extra_info["candidates"] = suggestion; } - return CatalogException(StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, - version_info, did_you_mean), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name %s does not exist%s!%s", CatalogTypeToString(type), name, + version_info, did_you_mean)); } CatalogException CatalogException::MissingEntry(CatalogType type, const string &name, const string &suggestion, @@ -55,17 +55,17 @@ CatalogException CatalogException::MissingEntry(const string &type, const string if (!suggestions.empty()) { extra_info["candidates"] = StringUtil::Join(suggestions, ", "); } - return CatalogException(StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, - StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean")), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, + StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean"))); } CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const string &name, QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("ENTRY_ALREADY_EXISTS", optional_idx()); extra_info["name"] = name; extra_info["type"] = CatalogTypeToString(type); - return CatalogException(StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name), - extra_info); + return CatalogException(extra_info, + StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name)); } } // namespace duckdb diff --git a/src/duckdb/src/common/exception/conversion_exception.cpp b/src/duckdb/src/common/exception/conversion_exception.cpp index 013dbdb9e..bf021b4eb 100644 --- a/src/duckdb/src/common/exception/conversion_exception.cpp +++ b/src/duckdb/src/common/exception/conversion_exception.cpp @@ -17,7 +17,7 @@ ConversionException::ConversionException(const string &msg) : Exception(Exceptio } ConversionException::ConversionException(optional_idx error_location, const string &msg) - : Exception(ExceptionType::CONVERSION, msg, 
Exception::InitializeExtraInfo(error_location)) { + : Exception(Exception::InitializeExtraInfo(error_location), ExceptionType::CONVERSION, msg) { } } // namespace duckdb diff --git a/src/duckdb/src/common/exception/parser_exception.cpp b/src/duckdb/src/common/exception/parser_exception.cpp index f3875da38..3afb2ea3d 100644 --- a/src/duckdb/src/common/exception/parser_exception.cpp +++ b/src/duckdb/src/common/exception/parser_exception.cpp @@ -7,13 +7,12 @@ namespace duckdb { ParserException::ParserException(const string &msg) : Exception(ExceptionType::PARSER, msg) { } -ParserException::ParserException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::PARSER, msg, extra_info) { +ParserException::ParserException(const unordered_map &extra_info, const string &msg) + : Exception(extra_info, ExceptionType::PARSER, msg) { } ParserException ParserException::SyntaxError(const string &query, const string &error_message, optional_idx error_location) { - return ParserException(error_message, Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location)); + return ParserException(Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location), error_message); } - } // namespace duckdb diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp index 51e34ec0e..27b4eb465 100644 --- a/src/duckdb/src/common/exception_format_value.cpp +++ b/src/duckdb/src/common/exception_format_value.cpp @@ -28,65 +28,61 @@ ExceptionFormatValue::ExceptionFormatValue(uhugeint_t uhuge_val) ExceptionFormatValue::ExceptionFormatValue(string str_val) : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { } -ExceptionFormatValue::ExceptionFormatValue(String str_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(str_val.ToStdString()) { +ExceptionFormatValue::ExceptionFormatValue(const String &str_val) : ExceptionFormatValue(str_val.ToStdString()) { } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value) { return ExceptionFormatValue(TypeIdToString(value)); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value) { return ExceptionFormatValue(value.ToString()); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value) { + return ExceptionFormatValue(static_cast(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { - return ExceptionFormatValue(double(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { - return ExceptionFormatValue(std::move(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value) { + return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value) { - return ExceptionFormatValue(std::move(value)); +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value) { + return 
ExceptionFormatValue(value); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value) { return KeywordHelper::WriteQuoted(value.raw_string, '\''); } template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value) { return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value) { return ExceptionFormatValue(string(value)); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value) { return ExceptionFormatValue(value); } template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value) { +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value) { return ExceptionFormatValue(value); } diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index 1d3160814..6218f3e7b 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -507,4 +507,19 @@ shared_ptr TemplateTypeInfo::Copy() const { return make_shared_ptr(*this); } +//===--------------------------------------------------------------------===// +// Geo Type Info +//===--------------------------------------------------------------------===// +GeoTypeInfo::GeoTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::GEO_TYPE_INFO) { +} + +bool GeoTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + // No additional info to compare + return true; +} + +shared_ptr GeoTypeInfo::Copy() const { + return make_shared_ptr(*this); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp index 8e108ddc1..94223eed3 100644 --- a/src/duckdb/src/common/file_buffer.cpp +++ b/src/duckdb/src/common/file_buffer.cpp @@ -1,6 +1,6 @@ #include "duckdb/common/file_buffer.hpp" -#include "duckdb/common/allocator.hpp" +#include "duckdb/storage/block_allocator.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/helper.hpp" @@ -12,7 +12,7 @@ namespace duckdb { -FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size, idx_t block_header_size) +FileBuffer::FileBuffer(BlockAllocator &allocator, FileBufferType type, uint64_t user_size, idx_t block_header_size) : allocator(allocator), type(type) { Init(); if (user_size) { @@ -20,7 +20,7 @@ FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_ } } -FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, BlockManager &block_manager) 
+FileBuffer::FileBuffer(BlockAllocator &allocator, FileBufferType type, BlockManager &block_manager) : allocator(allocator), type(type) { Init(); Resize(block_manager); diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 926cfb6a0..782c30301 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -30,10 +30,7 @@ #include #ifdef __MVS__ -#define _XOPEN_SOURCE_EXTENDED 1 #include -// enjoy - https://reviews.llvm.org/D92110 -#define PATH_MAX _XOPEN_PATH_MAX #endif #else @@ -453,6 +450,10 @@ FileType FileSystem::GetFileType(FileHandle &handle) { return FileType::FILE_TYPE_INVALID; } +FileMetadata FileSystem::Stats(FileHandle &handle) { + throw NotImplementedException("%s: Stats is not implemented!", GetName()); +} + void FileSystem::Truncate(FileHandle &handle, int64_t new_size) { throw NotImplementedException("%s: Truncate is not implemented!", GetName()); } @@ -620,6 +621,10 @@ bool FileSystem::SubSystemIsDisabled(const string &name) { throw NotImplementedException("%s: Non-virtual file system does not have subsystems", GetName()); } +bool FileSystem::IsDisabledForPath(const string &path) { + throw NotImplementedException("%s: Non-virtual file system does not have subsystems", GetName()); +} + vector FileSystem::ListSubSystems() { throw NotImplementedException("%s: Can't list sub systems on a non-virtual file system", GetName()); } @@ -628,39 +633,9 @@ bool FileSystem::CanHandleFile(const string &fpath) { throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); } -static string LookupExtensionForPattern(const string &pattern) { - for (const auto &entry : EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(pattern, entry.name)) { - return entry.extension; - } - } - return ""; -} - vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, const FileGlobInput &input) { auto result = Glob(pattern); if (result.empty()) { - string required_extension = LookupExtensionForPattern(pattern); - if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { - auto &dbconfig = DBConfig::GetConfig(context); - if (!ExtensionHelper::CanAutoloadExtension(required_extension) || - !dbconfig.options.autoload_known_extensions) { - auto error_message = - "File " + pattern + " requires the extension " + required_extension + " to be loaded"; - error_message = - ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); - throw MissingExtensionException(error_message); - } - // an extension is required to read this file, but it is not loaded - try to load it - ExtensionHelper::AutoLoadExtension(context, required_extension); - // success! 
glob again - // check the extension is loaded just in case to prevent an infinite loop here - if (!context.db->ExtensionIsLoaded(required_extension)) { - throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", - required_extension); - } - return GlobFiles(pattern, context, input); - } if (input.behavior == FileGlobOptions::FALLBACK_GLOB && !HasGlob(pattern)) { // if we have no glob in the pattern and we have an extension, we try to glob if (!HasGlob(pattern)) { @@ -724,7 +699,7 @@ int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { int64_t FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesRead(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_READ, nr_bytes); } return file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes)); @@ -744,7 +719,7 @@ void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesRead(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_READ, nr_bytes); } file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); @@ -752,7 +727,7 @@ void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t void FileHandle::Write(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddBytesWritten(nr_bytes); + context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_WRITTEN, nr_bytes); } file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); @@ -830,6 +805,10 @@ FileType FileHandle::GetType() { return file_system.GetFileType(*this); } +FileMetadata FileHandle::Stats() { + return file_system.Stats(*this); +} + void FileHandle::TryAddLogger(FileOpener &opener) { if (flags.DisableLogging()) { return; diff --git a/src/duckdb/src/common/gzip_file_system.cpp b/src/duckdb/src/common/gzip_file_system.cpp index 92b4e10d2..c3ee6d27d 100644 --- a/src/duckdb/src/common/gzip_file_system.cpp +++ b/src/duckdb/src/common/gzip_file_system.cpp @@ -93,7 +93,19 @@ MiniZStreamWrapper::~MiniZStreamWrapper() { } try { MiniZStreamWrapper::Close(); - } catch (...) { // NOLINT - cannot throw in exception + } catch (std::exception &ex) { + if (file && file->child_handle) { + // FIXME: Make any log context available here. + ErrorData data(ex); + try { + const auto logger = file->child_handle->logger; + if (logger) { + DUCKDB_LOG_ERROR(logger, "MiniZStreamWrapper::~MiniZStreamWrapper()\t\t" + data.Message()) + } + } catch (...) { // NOLINT + } + } + } catch (...) 
{ // NOLINT } } diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp index 932943b8f..78f3b40e8 100644 --- a/src/duckdb/src/common/hive_partitioning.cpp +++ b/src/duckdb/src/common/hive_partitioning.cpp @@ -153,7 +153,6 @@ void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector> &filters, const HivePartitioningFilterInfo &filter_info, MultiFilePushdownInfo &info) { - vector pruned_files; vector have_preserved_filter(filters.size(), false); vector> pruned_filters; diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index 8733e0162..3d9cae503 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -167,29 +167,45 @@ struct UnixFileHandle : public FileHandle { }; }; -static FileType GetFileTypeInternal(int fd) { // LCOV_EXCL_START +static FileMetadata StatsInternal(int fd, const string &path) { struct stat s; if (fstat(fd, &s) == -1) { - return FileType::FILE_TYPE_INVALID; + throw IOException({{"errno", std::to_string(errno)}}, "Failed to get stats for file \"%s\": %s", path, + strerror(errno)); } + + FileMetadata file_metadata; + file_metadata.file_size = s.st_size; + file_metadata.last_modification_time = Timestamp::FromEpochSeconds(s.st_mtime); + switch (s.st_mode & S_IFMT) { case S_IFBLK: - return FileType::FILE_TYPE_BLOCKDEV; + file_metadata.file_type = FileType::FILE_TYPE_BLOCKDEV; + break; case S_IFCHR: - return FileType::FILE_TYPE_CHARDEV; + file_metadata.file_type = FileType::FILE_TYPE_CHARDEV; + break; case S_IFIFO: - return FileType::FILE_TYPE_FIFO; + file_metadata.file_type = FileType::FILE_TYPE_FIFO; + break; case S_IFDIR: - return FileType::FILE_TYPE_DIR; + file_metadata.file_type = FileType::FILE_TYPE_DIR; + break; case S_IFLNK: - return FileType::FILE_TYPE_LINK; + file_metadata.file_type = FileType::FILE_TYPE_LINK; + break; case S_IFREG: - return FileType::FILE_TYPE_REGULAR; + file_metadata.file_type = FileType::FILE_TYPE_REGULAR; + break; case S_IFSOCK: - return FileType::FILE_TYPE_SOCKET; + file_metadata.file_type = FileType::FILE_TYPE_SOCKET; + break; default: - return FileType::FILE_TYPE_INVALID; + file_metadata.file_type = FileType::FILE_TYPE_INVALID; + break; } + + return file_metadata; } // LCOV_EXCL_STOP #if __APPLE__ && !TARGET_OS_IPHONE @@ -369,7 +385,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (flags.ReturnNullIfExists() && errno == EEXIST) { return nullptr; } - throw IOException("Cannot open file \"%s\": %s", {{"errno", std::to_string(errno)}}, path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, "Cannot open file \"%s\": %s", path, strerror(errno)); } #if defined(__DARWIN__) || defined(__APPLE__) @@ -385,7 +401,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (flags.Lock() != FileLockType::NO_LOCK) { // set lock on file // but only if it is not an input/output stream - auto file_type = GetFileTypeInternal(fd); + auto file_type = StatsInternal(fd, path_p).file_type; if (file_type != FileType::FILE_TYPE_FIFO && file_type != FileType::FILE_TYPE_SOCKET) { struct flock fl; memset(&fl, 0, sizeof fl); @@ -436,7 +452,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF extended_error += ". Also, failed closing file"; } extended_error += ". 
See also https://duckdb.org/docs/stable/connect/concurrency"; - throw IOException("Could not set lock on file \"%s\": %s", {{"errno", std::to_string(retained_errno)}}, + throw IOException({{"errno", std::to_string(retained_errno)}}, "Could not set lock on file \"%s\": %s", path, extended_error); } } @@ -454,7 +470,7 @@ void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { int fd = handle.Cast().fd; off_t offset = lseek(fd, UnsafeNumericCast(location), SEEK_SET); if (offset == (off_t)-1) { - throw IOException("Could not seek to location %lld for file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not seek to location %lld for file \"%s\": %s", location, handle.path, strerror(errno)); } } @@ -463,7 +479,7 @@ idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { int fd = handle.Cast().fd; off_t position = lseek(fd, 0, SEEK_CUR); if (position == (off_t)-1) { - throw IOException("Could not get file position file \"%s\": %s", {{"errno", std::to_string(errno)}}, + throw IOException({{"errno", std::to_string(errno)}}, "Could not get file position file \"%s\": %s", handle.path, strerror(errno)); } return UnsafeNumericCast(position); @@ -477,7 +493,7 @@ void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, i int64_t bytes_read = pread(fd, read_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_read == 0) { @@ -498,7 +514,7 @@ int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes int fd = unix_handle.fd; int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not read from file \"%s\": %s", handle.path, strerror(errno)); } @@ -519,12 +535,13 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, int64_t bytes_written = pwrite(fd, write_buffer, UnsafeNumericCast(bytes_to_write), UnsafeNumericCast(current_location)); if (bytes_written < 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } if (bytes_written == 0) { - throw IOException("Could not write to file \"%s\" - attempted to write 0 bytes: %s", - {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); + throw IOException({{"errno", std::to_string(errno)}}, + "Could not write to file \"%s\" - attempted to write 0 bytes: %s", handle.path, + strerror(errno)); } write_buffer += bytes_written; bytes_to_write -= bytes_written; @@ -544,7 +561,7 @@ int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_byte MinValue(idx_t(NumericLimits::Maximum()), idx_t(bytes_to_write)); int64_t current_bytes_written = write(fd, buffer, bytes_to_write_this_call); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, 
strerror(errno)); } buffer = (void *)(data_ptr_cast(buffer) + current_bytes_written); @@ -574,34 +591,30 @@ bool LocalFileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_ } int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - throw IOException("Failed to get file size for file \"%s\": %s", {{"errno", std::to_string(errno)}}, - handle.path, strerror(errno)); - } - return s.st_size; + const auto file_metadata = Stats(handle); + return file_metadata.file_size; } timestamp_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - throw IOException("Failed to get last modified time for file \"%s\": %s", {{"errno", std::to_string(errno)}}, - handle.path, strerror(errno)); - } - return Timestamp::FromEpochSeconds(s.st_mtime); + const auto file_metadata = Stats(handle); + return file_metadata.last_modification_time; } FileType LocalFileSystem::GetFileType(FileHandle &handle) { + const auto file_metadata = Stats(handle); + return file_metadata.file_type; +} + +FileMetadata LocalFileSystem::Stats(FileHandle &handle) { int fd = handle.Cast().fd; - return GetFileTypeInternal(fd); + auto file_metadata = StatsInternal(fd, handle.GetPath()); + return file_metadata; } void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { int fd = handle.Cast().fd; if (ftruncate(fd, new_size) != 0) { - throw IOException("Could not truncate file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not truncate file \"%s\": %s", handle.path, strerror(errno)); } } @@ -612,7 +625,7 @@ bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { auto normalized_file = NormalizeLocalPath(filename); if (std::remove(normalized_file) != 0) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", std::to_string(errno)}}, filename, + throw IOException({{"errno", std::to_string(errno)}}, "Could not remove file \"%s\": %s", filename, strerror(errno)); } } @@ -718,7 +731,7 @@ bool LocalFileSystem::ListFilesExtended(const string &directory, if (res != 0) { continue; } - if (!(status.st_mode & S_IFREG) && !(status.st_mode & S_IFDIR)) { + if (!S_ISREG(status.st_mode) && !S_ISDIR(status.st_mode)) { // not a file or directory: skip continue; } @@ -726,7 +739,7 @@ bool LocalFileSystem::ListFilesExtended(const string &directory, info.extended_info = make_shared_ptr(); auto &options = info.extended_info->options; // file type - Value file_type(status.st_mode & S_IFDIR ? "directory" : "file"); + Value file_type(S_ISDIR(status.st_mode) ? "directory" : "file"); options.emplace("type", std::move(file_type)); // file size options.emplace("file_size", Value::BIGINT(UnsafeNumericCast(status.st_size))); @@ -767,8 +780,7 @@ void LocalFileSystem::FileSync(FileHandle &handle) { } // For other types of errors, throw normal IO exception. - throw IOException("Could not fsync file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.GetPath(), - strerror(errno)); + throw IOException("Could not fsync file \"%s\": %s", handle.GetPath(), strerror(errno)); } void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -776,7 +788,7 @@ void LocalFileSystem::MoveFile(const string &source, const string &target, optio auto normalized_target = NormalizeLocalPath(target); //! 
FIXME: rename does not guarantee atomicity or overwriting target file if it exists if (rename(normalized_source, normalized_target) != 0) { - throw IOException("Could not rename file!", {{"errno", std::to_string(errno)}}); + throw IOException({{"errno", std::to_string(errno)}}, "Could not rename file!"); } } @@ -814,6 +826,72 @@ std::string LocalFileSystem::GetLastErrorAsString() { return message; } +static timestamp_t FiletimeToTimeStamp(FILETIME file_time) { + // https://stackoverflow.com/questions/29266743/what-is-dwlowdatetime-and-dwhighdatetime + ULARGE_INTEGER ul; + ul.LowPart = file_time.dwLowDateTime; + ul.HighPart = file_time.dwHighDateTime; + int64_t fileTime64 = ul.QuadPart; + + // fileTime64 contains a 64-bit value representing the number of + // 100-nanosecond intervals since January 1, 1601 (UTC). + // https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime + + // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux + const auto WINDOWS_TICK = 10000000; + const auto SEC_TO_UNIX_EPOCH = 11644473600LL; + return Timestamp::FromTimeT(fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); +} + +static FileMetadata StatsInternal(HANDLE hFile, const string &path) { + FileMetadata file_metadata; + + DWORD handle_type = GetFileType(hFile); + if (handle_type == FILE_TYPE_CHAR) { + file_metadata.file_type = FileType::FILE_TYPE_CHARDEV; + file_metadata.file_size = 0; + file_metadata.last_modification_time = Timestamp::FromTimeT(0); + return file_metadata; + } + if (handle_type == FILE_TYPE_PIPE) { + file_metadata.file_type = FileType::FILE_TYPE_FIFO; + file_metadata.file_size = 0; + file_metadata.last_modification_time = Timestamp::FromTimeT(0); + return file_metadata; + } + + BY_HANDLE_FILE_INFORMATION file_info; + if (!GetFileInformationByHandle(hFile, &file_info)) { + auto error = LocalFileSystem::GetLastErrorAsString(); + throw IOException("Failed to get stats for file \"%s\": %s", path, error); + } + + // Get file size from high and low parts. + file_metadata.file_size = + (static_cast(file_info.nFileSizeHigh) << 32) | static_cast(file_info.nFileSizeLow); + + // Get last modification time + file_metadata.last_modification_time = FiletimeToTimeStamp(file_info.ftLastWriteTime); + + // Get file type from attributes + if (strncmp(path.c_str(), PIPE_PREFIX, strlen(PIPE_PREFIX)) == 0) { + // pipes in windows are just files in '\\.\pipe\' folder + file_metadata.file_type = FileType::FILE_TYPE_FIFO; + } else if (file_info.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) { + file_metadata.file_type = FileType::FILE_TYPE_DIR; + } else if (file_info.dwFileAttributes & FILE_ATTRIBUTE_DEVICE) { + file_metadata.file_type = FileType::FILE_TYPE_CHARDEV; + } else if (file_info.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) { + file_metadata.file_type = FileType::FILE_TYPE_LINK; + } else if (file_info.dwFileAttributes != INVALID_FILE_ATTRIBUTES) { + file_metadata.file_type = FileType::FILE_TYPE_REGULAR; + } else { + file_metadata.file_type = FileType::FILE_TYPE_INVALID; + } + + return file_metadata; +} + struct WindowsFileHandle : public FileHandle { public: WindowsFileHandle(FileSystem &file_system, string path, HANDLE fd, FileOpenFlags flags) @@ -943,6 +1021,10 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF default: throw InternalException("Unknown FileLockType"); } + // For windows platform, by default deletion fails when the file is accessed by other thread/process. 
+ // To keep deletion behavior compatible with unix platform, which physically deletes a file when reference count + // drops to 0 without interfering with already opened file handles, open files with [`FILE_SHARE_DELETE`]. + share_mode |= FILE_SHARE_DELETE; if (open_write) { if (flags.CreateFileIfNotExists()) { @@ -1052,7 +1134,7 @@ static int64_t FSWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t n auto bytes_to_write = MinValue(idx_t(NumericLimits::Maximum()), idx_t(nr_bytes)); DWORD current_bytes_written = FSInternalWrite(handle, hFile, buffer, bytes_to_write, location); if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, + throw IOException({{"errno", std::to_string(errno)}}, "Could not write file \"%s\": %s", handle.path, strerror(errno)); } bytes_written += current_bytes_written; @@ -1088,42 +1170,13 @@ bool LocalFileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_ } int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - LARGE_INTEGER result; - if (!GetFileSizeEx(hFile, &result)) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to get file size for file \"%s\": %s", handle.path, error); - } - return result.QuadPart; -} - -timestamp_t FiletimeToTimeStamp(FILETIME file_time) { - // https://stackoverflow.com/questions/29266743/what-is-dwlowdatetime-and-dwhighdatetime - ULARGE_INTEGER ul; - ul.LowPart = file_time.dwLowDateTime; - ul.HighPart = file_time.dwHighDateTime; - int64_t fileTime64 = ul.QuadPart; - - // fileTime64 contains a 64-bit value representing the number of - // 100-nanosecond intervals since January 1, 1601 (UTC). - // https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime - - // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux - const auto WINDOWS_TICK = 10000000; - const auto SEC_TO_UNIX_EPOCH = 11644473600LL; - return Timestamp::FromTimeT(fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); + const auto file_metadata = Stats(handle); + return file_metadata.file_size; } timestamp_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - - // https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfiletime - FILETIME last_write; - if (GetFileTime(hFile, nullptr, nullptr, &last_write) == 0) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to get last modified time for file \"%s\": %s", handle.path, error); - } - return FiletimeToTimeStamp(last_write); + const auto file_metadata = Stats(handle); + return file_metadata.last_modification_time; } void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { @@ -1250,28 +1303,22 @@ void LocalFileSystem::FileSync(FileHandle &handle) { void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { auto source_unicode = NormalizePathAndConvertToUnicode(source); auto target_unicode = NormalizePathAndConvertToUnicode(target); + DWORD flags = MOVEFILE_REPLACE_EXISTING | MOVEFILE_WRITE_THROUGH; - if (!MoveFileW(source_unicode.c_str(), target_unicode.c_str())) { + if (!MoveFileExW(source_unicode.c_str(), target_unicode.c_str(), flags)) { throw IOException("Could not move file: %s", GetLastErrorAsString()); } } FileType LocalFileSystem::GetFileType(FileHandle &handle) { - auto path = handle.Cast().path; - // pipes in windows 
are just files in '\\.\pipe\' folder - if (strncmp(path.c_str(), PIPE_PREFIX, strlen(PIPE_PREFIX)) == 0) { - return FileType::FILE_TYPE_FIFO; - } - auto normalized_path = NormalizePathAndConvertToUnicode(path); - DWORD attrs = WindowsGetFileAttributes(normalized_path); - if (attrs != INVALID_FILE_ATTRIBUTES) { - if (attrs & FILE_ATTRIBUTE_DIRECTORY) { - return FileType::FILE_TYPE_DIR; - } else { - return FileType::FILE_TYPE_REGULAR; - } - } - return FileType::FILE_TYPE_INVALID; + const auto file_metadata = Stats(handle); + return file_metadata.file_type; +} + +FileMetadata LocalFileSystem::Stats(FileHandle &handle) { + HANDLE hFile = handle.Cast().fd; + auto file_metadata = StatsInternal(hFile, handle.GetPath()); + return file_metadata; } #endif @@ -1283,6 +1330,28 @@ bool LocalFileSystem::OnDiskFile(FileHandle &handle) { return true; } +string LocalFileSystem::GetVersionTag(FileHandle &handle) { + // TODO: Fix using FileSystem::Stats for v1.5, which should also fix it for Windows +#ifdef _WIN32 + return ""; +#else + int fd = handle.Cast().fd; + struct stat s; + if (fstat(fd, &s) == -1) { + throw IOException("Failed to get file size for file \"%s\": %s", handle.path, strerror(errno)); + } + + // dev/ino should be enough, but to guard against in-place writes we also add file size and modification time + uint64_t version_tag[4]; + Store(NumericCast(s.st_dev), data_ptr_cast(&version_tag[0])); + Store(NumericCast(s.st_ino), data_ptr_cast(&version_tag[1])); + Store(NumericCast(s.st_size), data_ptr_cast(&version_tag[2])); + Store(Timestamp::FromEpochSeconds(s.st_mtime).value, data_ptr_cast(&version_tag[3])); + + return string(char_ptr_cast(version_tag), sizeof(uint64_t) * 4); +#endif +} + void LocalFileSystem::Seek(FileHandle &handle, idx_t location) { if (!CanSeek()) { throw IOException("Cannot seek in files of this type"); @@ -1319,7 +1388,6 @@ static bool IsSymbolicLink(const string &path) { static void RecursiveGlobDirectories(FileSystem &fs, const string &path, vector &result, bool match_directory, bool join_path) { - fs.ListFiles(path, [&](OpenFileInfo &info) { if (join_path) { info.path = fs.JoinPath(path, info.path); diff --git a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp index c4537738d..b90bbd691 100644 --- a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp @@ -602,7 +602,7 @@ unique_ptr ConstructMapExpression(ClientContext &context, idx_t loca children.push_back(std::move(mapping.default_value)); } auto remap_fun = RemapStructFun::GetFunction(); - auto bind_data = remap_fun.bind(context, remap_fun, children); + auto bind_data = remap_fun.GetBindCallback()(context, remap_fun, children); children[0] = BoundCastExpression::AddCastToType(context, std::move(children[0]), remap_fun.arguments[0]); return make_uniq(global_column.type, std::move(remap_fun), std::move(children), std::move(bind_data)); @@ -866,6 +866,10 @@ bool MultiFileColumnMapper::EvaluateFilterAgainstConstant(TableFilter &filter, c auto &expr_filter = filter.Cast(); return expr_filter.EvaluateWithConstant(context, constant); } + case TableFilterType::BLOOM_FILTER: { + auto &bloom_filter = filter.Cast(); + return bloom_filter.FilterValue(constant); + } default: throw NotImplementedException("Can't evaluate TableFilterType (%s) against a constant", EnumUtil::ToString(type)); diff --git a/src/duckdb/src/common/multi_file/multi_file_reader.cpp 
b/src/duckdb/src/common/multi_file/multi_file_reader.cpp index 21413261e..d69e4a187 100644 --- a/src/duckdb/src/common/multi_file/multi_file_reader.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_reader.cpp @@ -11,6 +11,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/multi_file/multi_file_function.hpp" #include "duckdb/common/multi_file/union_by_name.hpp" @@ -299,7 +300,6 @@ void MultiFileReader::FinalizeBind(MultiFileReaderData &reader_data, const Multi const vector &global_columns, const vector &global_column_ids, ClientContext &context, optional_ptr global_state) { - // create a map of name -> column index auto &local_columns = reader_data.reader->GetColumns(); auto &filename = reader_data.reader->GetFileName(); diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp index f26c16131..5998d7787 100644 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -19,6 +19,7 @@ #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types.hpp" #include "fast_float/fast_float.h" #include "duckdb/common/types/bit.hpp" @@ -1406,7 +1407,6 @@ string_t CastFromBlobToBit::Operation(string_t input, Vector &vector) { //===--------------------------------------------------------------------===// template <> string_t CastFromBitToString::Operation(string_t input, Vector &vector) { - idx_t result_size = Bit::BitLength(input); string_t result = StringVector::EmptyString(vector, result_size); Bit::ToString(input, result.GetDataWriteable()); @@ -1560,6 +1560,14 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict return true; } +//===--------------------------------------------------------------------===// +// Cast To Geometry +//===--------------------------------------------------------------------===// +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters) { + return Geometry::FromString(input, result, result_vector, parameters.strict); +} + //===--------------------------------------------------------------------===// // Cast To Date //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/pipe_file_system.cpp b/src/duckdb/src/common/pipe_file_system.cpp index dc9b7d108..6f0a29741 100644 --- a/src/duckdb/src/common/pipe_file_system.cpp +++ b/src/duckdb/src/common/pipe_file_system.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/local_file_system.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/main/client_context.hpp" @@ -53,6 +54,11 @@ int64_t PipeFileSystem::GetFileSize(FileHandle &handle) { return 0; } +timestamp_t PipeFileSystem::GetLastModifiedTime(FileHandle &handle) { + auto &child_handle = *handle.Cast().child_handle; + return child_handle.file_system.GetLastModifiedTime(child_handle); +} + void PipeFileSystem::FileSync(FileHandle &handle) { } diff --git 
a/src/duckdb/src/common/printer.cpp b/src/duckdb/src/common/printer.cpp index e07d3f8f2..46187ac24 100644 --- a/src/duckdb/src/common/printer.cpp +++ b/src/duckdb/src/common/printer.cpp @@ -31,8 +31,7 @@ void Printer::RawPrint(OutputStream stream, const string &str) { } void Printer::DefaultLinePrint(OutputStream stream, const string &str) { - Printer::RawPrint(stream, str); - Printer::RawPrint(stream, "\n"); + Printer::RawPrint(stream, str + "\n"); } line_printer_f Printer::line_printer = Printer::DefaultLinePrint; diff --git a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp index 417ac609a..50b44bf97 100644 --- a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp +++ b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp @@ -14,7 +14,7 @@ int32_t TerminalProgressBarDisplay::NormalizePercentage(double percentage) { return int32_t(percentage); } -static string FormatETA(double seconds, bool elapsed = false) { +string TerminalProgressBarDisplay::FormatETA(double seconds, bool elapsed) { // for terminal rendering purposes, we need to make sure the length is always the same // we pad the end with spaces if that is not the case // the maximum length here is "(~10.35 minutes remaining)" (26 bytes) @@ -68,14 +68,38 @@ static string FormatETA(double seconds, bool elapsed = false) { return result; } -void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, double seconds, bool finished) { - string result; +string TerminalProgressBarDisplay::FormatProgressBar(const ProgressBarDisplayInfo &display, int32_t percentage) { // we divide the number of blocks by the percentage // 0% = 0 // 100% = PROGRESS_BAR_WIDTH // the percentage determines how many blocks we need to draw - double blocks_to_draw = PROGRESS_BAR_WIDTH * (percentage / 100.0); + double blocks_to_draw = static_cast(display.width) * (percentage / 100.0); // because of the power of unicode, we can also draw partial blocks + string result; + result += display.progress_start; + idx_t i; + for (i = 0; i < idx_t(blocks_to_draw); i++) { + result += display.progress_block; + } + if (i < display.width) { + // print a partial block based on the percentage of the progress bar remaining + idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * + static_cast(display.partial_block_count)); + if (index >= display.partial_block_count) { + index = display.partial_block_count - 1; + } + result += display.progress_partial[index]; + i++; + } + for (; i < display.width; i++) { + result += display.progress_empty; + } + result += display.progress_end; + return result; +} + +void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, double seconds, bool finished) { + string result; // render the percentage with some padding to ensure everything stays nicely aligned result = "\r"; @@ -87,24 +111,7 @@ void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage, doubl } result += to_string(percentage) + "%"; result += " "; - result += PROGRESS_START; - idx_t i; - for (i = 0; i < idx_t(blocks_to_draw); i++) { - result += PROGRESS_BLOCK; - } - if (i < PROGRESS_BAR_WIDTH) { - // print a partial block based on the percentage of the progress bar remaining - idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * PARTIAL_BLOCK_COUNT); - if (index >= PARTIAL_BLOCK_COUNT) { - index = PARTIAL_BLOCK_COUNT - 1; - } - result += PROGRESS_PARTIAL[index]; - i++; - } - for 
(; i < PROGRESS_BAR_WIDTH; i++) { - result += PROGRESS_EMPTY; - } - result += PROGRESS_END; + result += FormatProgressBar(display_info, percentage); result += " "; result += FormatETA(seconds, finished); diff --git a/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp b/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp index 551cdbb0f..158d59fd7 100644 --- a/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp +++ b/src/duckdb/src/common/progress_bar/unscented_kalman_filter.cpp @@ -6,7 +6,6 @@ UnscentedKalmanFilter::UnscentedKalmanFilter() : x(STATE_DIM, 0.0), P(STATE_DIM, vector(STATE_DIM, 0.0)), Q(STATE_DIM, vector(STATE_DIM, 0.0)), R(OBS_DIM, vector(OBS_DIM, 0.0)), last_time(0.0), initialized(false), last_progress(-1.0), scale_factor(1.0) { - // Calculate UKF parameters lambda = ALPHA * ALPHA * (STATE_DIM + KAPPA) - STATE_DIM; @@ -254,11 +253,11 @@ void UnscentedKalmanFilter::UpdateInternal(double measured_progress) { } // Ensure progress stays in bounds - x[0] = std::max(0.0, std::min(1.0, x[0])); + x[0] = std::max(0.0, std::min(scale_factor, x[0])); } double UnscentedKalmanFilter::GetProgress() const { - return x[0]; + return x[0] / scale_factor; } double UnscentedKalmanFilter::GetVelocity() const { diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp index 487e106af..6b4ead263 100644 --- a/src/duckdb/src/common/radix_partitioning.cpp +++ b/src/duckdb/src/common/radix_partitioning.cpp @@ -98,6 +98,7 @@ struct ComputePartitionIndicesFunctor { const auto source_data = UnifiedVectorFormat::GetData(format); const auto &source_sel = *format.sel; + partition_indices.SetVectorType(VectorType::FLAT_VECTOR); const auto target = FlatVector::GetData(partition_indices); if (source_sel.IsSet()) { @@ -169,16 +170,16 @@ void RadixPartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAp // Tuple Data Partitioning //===--------------------------------------------------------------------===// RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manager, - shared_ptr layout_ptr, const idx_t radix_bits_p, - const idx_t hash_col_idx_p) - : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_ptr), radix_bits(radix_bits_p), + shared_ptr layout_ptr, const MemoryTag tag, + const idx_t radix_bits_p, const idx_t hash_col_idx_p) + : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_ptr, tag), radix_bits(radix_bits_p), hash_col_idx(hash_col_idx_p) { D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); D_ASSERT(hash_col_idx < layout.GetTypes().size()); Initialize(); } -RadixPartitionedTupleData::RadixPartitionedTupleData(const RadixPartitionedTupleData &other) +RadixPartitionedTupleData::RadixPartitionedTupleData(RadixPartitionedTupleData &other) : PartitionedTupleData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { Initialize(); } @@ -189,7 +190,7 @@ RadixPartitionedTupleData::~RadixPartitionedTupleData() { void RadixPartitionedTupleData::Initialize() { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); for (idx_t i = 0; i < num_partitions; i++) { - partitions.emplace_back(CreatePartitionCollection(i)); + partitions.emplace_back(CreatePartitionCollection()); partitions.back()->SetPartitionIndex(i); } } diff --git a/src/duckdb/src/common/random_engine.cpp b/src/duckdb/src/common/random_engine.cpp index 78403e030..156b4baec 100644 --- a/src/duckdb/src/common/random_engine.cpp +++ 
b/src/duckdb/src/common/random_engine.cpp @@ -82,4 +82,14 @@ void RandomEngine::SetSeed(uint64_t seed) { random_state->pcg.seed(seed); } +void RandomEngine::RandomData(duckdb::data_ptr_t data, duckdb::idx_t len) { + while (len) { + const auto random_integer = NextRandomInteger(); + const auto next = duckdb::MinValue(len, sizeof(random_integer)); + memcpy(data, duckdb::const_data_ptr_cast(&random_integer), next); + data += next; + len -= next; + } +} + } // namespace duckdb diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp index 582d5e1ad..f3bb9d54a 100644 --- a/src/duckdb/src/common/render_tree.cpp +++ b/src/duckdb/src/common/render_tree.cpp @@ -102,22 +102,22 @@ static unique_ptr CreateNode(const PipelineRenderNode &op) { static unique_ptr CreateNode(const ProfilingNode &op) { auto &info = op.GetProfilingInfo(); InsertionOrderPreservingMap extra_info; - if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - extra_info = op.GetProfilingInfo().extra_info; + if (info.Enabled(info.settings, MetricType::EXTRA_INFO)) { + extra_info = op.GetProfilingInfo().GetMetricValue>(MetricType::EXTRA_INFO); } string node_name = "QUERY"; if (op.depth > 0) { - node_name = info.GetMetricAsString(MetricsType::OPERATOR_TYPE); + node_name = info.GetMetricAsString(MetricType::OPERATOR_TYPE); } auto result = make_uniq(node_name, extra_info); - if (info.Enabled(info.settings, MetricsType::OPERATOR_CARDINALITY)) { - auto cardinality = info.GetMetricAsString(MetricsType::OPERATOR_CARDINALITY); + if (info.Enabled(info.settings, MetricType::OPERATOR_CARDINALITY)) { + auto cardinality = info.GetMetricAsString(MetricType::OPERATOR_CARDINALITY); result->extra_text[RenderTreeNode::CARDINALITY] = cardinality; } - if (info.Enabled(info.settings, MetricsType::OPERATOR_TIMING)) { - auto value = info.metrics.at(MetricsType::OPERATOR_TIMING).GetValue(); + if (info.Enabled(info.settings, MetricType::OPERATOR_TIMING)) { + auto value = info.metrics.at(MetricType::OPERATOR_TIMING).GetValue(); string timing = StringUtil::Format("%.2f", value); result->extra_text[RenderTreeNode::TIMING] = timing + "s"; } diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp index 73fa81de2..6d53b9d8a 100644 --- a/src/duckdb/src/common/row_operations/row_aggregate.cpp +++ b/src/duckdb/src/common/row_operations/row_aggregate.cpp @@ -17,11 +17,12 @@ void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, for (const auto &aggr : layout.GetAggregates()) { if (sel.IsSet()) { for (idx_t i = 0; i < count; ++i) { - aggr.function.initialize(aggr.function, pointers[sel.get_index_unsafe(i)] + offsets[aggr_idx]); + aggr.function.GetStateInitCallback()(aggr.function, + pointers[sel.get_index_unsafe(i)] + offsets[aggr_idx]); } } else { for (idx_t i = 0; i < count; ++i) { - aggr.function.initialize(aggr.function, pointers[i] + offsets[aggr_idx]); + aggr.function.GetStateInitCallback()(aggr.function, pointers[i] + offsets[aggr_idx]); } } ++aggr_idx; @@ -35,9 +36,9 @@ void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &la // Move to the first aggregate state VectorOperations::AddInPlace(addresses, UnsafeNumericCast(layout.GetAggrOffset()), count); for (const auto &aggr : layout.GetAggregates()) { - if (aggr.function.destructor) { + if (aggr.function.HasStateDestructorCallback()) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.destructor(addresses, 
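The RandomEngine::RandomData addition earlier in this hunk fills a buffer of arbitrary length by repeatedly drawing a fixed-width random integer and copying min(len, sizeof) bytes of it. A self-contained sketch of that chunked-fill pattern, using std::mt19937 as a stand-in for DuckDB's pcg-based engine:

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <random>

// Fill `len` bytes at `data` from a 32-bit generator, one word at a time.
static void FillRandomBytes(uint8_t *data, size_t len, std::mt19937 &gen) {
	while (len > 0) {
		const uint32_t word = gen();
		const size_t chunk = std::min(len, sizeof(word)); // at most 4 bytes per draw
		std::memcpy(data, &word, chunk);
		data += chunk;
		len -= chunk;
	}
}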
aggr_input_data, count); + aggr.function.GetStateDestructorCallback()(addresses, aggr_input_data, count); } // Move to the next aggregate state VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggr.payload_size), count); @@ -47,8 +48,8 @@ void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &la void RowOperations::UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx, idx_t count) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.update(aggr.child_count == 0 ? nullptr : &payload.data[arg_idx], aggr_input_data, aggr.child_count, - addresses, count); + aggr.function.GetStateUpdateCallback()(aggr.child_count == 0 ? nullptr : &payload.data[arg_idx], aggr_input_data, + aggr.child_count, addresses, count); } void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, @@ -78,10 +79,10 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la idx_t offset = layout.GetAggrOffset(); for (auto &aggr : layout.GetAggregates()) { - D_ASSERT(aggr.function.combine); + D_ASSERT(aggr.function.HasStateCombineCallback()); AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - aggr.function.combine(sources, targets, aggr_input_data, count); + aggr.function.GetStateCombineCallback()(sources, targets, aggr_input_data, count); // Move to the next aggregate states VectorOperations::AddInPlace(sources, UnsafeNumericCast(aggr.payload_size), count); @@ -113,7 +114,7 @@ void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &l auto &target = result.data[aggr_idx + i]; auto &aggr = aggregates[i]; AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.finalize(addresses_copy, aggr_input_data, target, result.size(), 0); + aggr.function.GetStateFinalizeCallback()(addresses_copy, aggr_input_data, target, result.size(), 0); // Move to the next aggregate state VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(aggr.payload_size), result.size()); diff --git a/src/duckdb/src/common/row_operations/row_external.cpp b/src/duckdb/src/common/row_operations/row_external.cpp deleted file mode 100644 index e4e3ec87d..000000000 --- a/src/duckdb/src/common/row_operations/row_external.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == 
PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string pointer with the within-row offset (if not inlined) - Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), - string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr, layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < 
layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp deleted file mode 100644 index 8e5ed315b..000000000 --- a/src/duckdb/src/common/row_operations/row_gather.cpp +++ /dev/null @@ -1,176 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/constant_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
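The deleted row_external.cpp implemented pointer swizzling: before a row block leaves memory, absolute heap pointers stored inside the rows are overwritten with offsets relative to the heap base, and on reload the offsets are turned back into pointers by adding the new base address. A minimal sketch of the idea, independent of DuckDB's RowLayout; the Row struct is illustrative only.

#include <cstdint>
#include <vector>

// Each row stores a reference into a separate heap block. Before the block
// is spilled, the pointer is replaced by an offset from the heap base
// ("swizzling"); after reload, base + offset restores a valid pointer.
struct Row {
	uint64_t heap_slot; // holds either a pointer or an offset
};

static void SwizzlePointers(std::vector<Row> &rows, const uint8_t *heap_base) {
	for (auto &row : rows) {
		auto ptr = reinterpret_cast<const uint8_t *>(row.heap_slot);
		row.heap_slot = static_cast<uint64_t>(ptr - heap_base); // pointer -> offset
	}
}

static void UnswizzlePointers(std::vector<Row> &rows, uint8_t *heap_base) {
	for (auto &row : rows) {
		row.heap_slot = reinterpret_cast<uint64_t>(heap_base + row.heap_slot); // offset -> pointer
	}
}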
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array_uninitialized(count); - auto mask_locations = make_unsafe_uniq_array_uninitialized(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - NestedValidity parent_validity(mask_locations.get(), col_no); - RowOperations::HeapGather(col, count, col_sel, data_locations.get(), &parent_validity); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" 
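The deleted TemplatedGatherLoop reads a fixed-width column out of row-major storage: each row pointer plus a precomputed column offset yields the value, and a per-row validity bit says whether it is NULL. A simplified sketch, with a plain byte mask at the start of each row standing in for DuckDB's ValidityBytes helper:

#include <cstdint>
#include <cstring>
#include <vector>

// Gather the column at `col_offset` from `rows` into `out`, recording NULLs
// in `valid`. Validity bits live at the start of each row, one bit per column
// (a simplification of ValidityBytes).
template <class T>
static void GatherFixedWidth(const std::vector<const uint8_t *> &rows, size_t col_offset, size_t col_no,
                             std::vector<T> &out, std::vector<bool> &valid) {
	const size_t entry = col_no / 8;        // which validity byte
	const uint8_t bit = 1u << (col_no % 8); // which bit within it
	for (size_t i = 0; i < rows.size(); i++) {
		T value;
		std::memcpy(&value, rows[i] + col_offset, sizeof(T)); // Load<T> equivalent
		out.push_back(value);
		valid.push_back((rows[i][entry] & bit) != 0);
	}
}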
- - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::FLOAT: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::DOUBLE: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INTERVAL: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::VARCHAR: - GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); - break; - default: - throw InternalException("Unimplemented type for RowOperations::Gather"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_gather.cpp b/src/duckdb/src/common/row_operations/row_heap_gather.cpp deleted file mode 100644 index fa433c64e..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_gather.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -template -static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < count; ++i) { - const auto col_idx = sel.get_index(i); - target[col_idx] = Load(key_locations[i]); - key_locations[i] += sizeof(T); - } -} - -static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - auto len = Load(key_locations[i]); - key_locations[i] += sizeof(uint32_t); - target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); - key_locations[i] += len; - 
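HeapGatherStringVector above reads a 4-byte length followed by the string payload from a per-row heap cursor, advancing the cursor as it goes. A standalone sketch of that length-prefixed decoding, with std::string in place of string_t:

#include <cstdint>
#include <cstring>
#include <string>

// Read one length-prefixed string from `cursor` and advance it, mirroring
// the key_locations[i] bookkeeping in the heap gather code.
static std::string ReadLengthPrefixedString(const uint8_t *&cursor) {
	uint32_t len;
	std::memcpy(&len, cursor, sizeof(len)); // 4-byte length prefix
	cursor += sizeof(len);
	std::string result(reinterpret_cast<const char *>(cursor), len);
	cursor += len; // skip over the payload
	return result;
}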
} -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - NestedValidity parent_validity(struct_validitymask_locations, i); - RowOperations::HeapGather(*children[i], vcount, sel, key_locations, &parent_validity); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto &list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. 
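The list gather walks a packed validity bitmask one bit at a time, bumping the byte pointer after every eighth element. Below is that inner pattern in isolation; names are illustrative.

#include <cstdint>
#include <vector>

// Expand a packed bitmask (one bit per element, LSB first) into bools,
// advancing the byte pointer whenever eight bits have been consumed.
static std::vector<bool> UnpackValidity(const uint8_t *mask, uint64_t count) {
	std::vector<bool> valid(count);
	uint32_t offset_in_byte = 0;
	for (uint64_t i = 0; i < count; i++) {
		valid[i] = (*mask & (1u << offset_in_byte)) != 0;
		if (++offset_in_byte == 8) {
			mask++;
			offset_in_byte = 0;
		}
	}
	return valid;
}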
- auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapGatherArrayVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // Setup - auto &child_type = ArrayType::GetChildType(v.GetType()); - auto array_size = ArrayType::GetSize(v.GetType()); - auto &child_vector = ArrayVector::GetEntry(v); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < vcount; i++) { - // Setup validity mask - data_ptr_t array_validitymask_location = key_locations[i]; - key_locations[i] += array_validitymask_size; - - NestedValidity parent_validity(array_validitymask_location); - - // The size of each variable size entry is stored after the validity mask - // (if the child type is variable size) - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // row idx - const auto row_idx = sel.get_index(i); - - idx_t array_start = row_idx * array_size; - idx_t elem_remaining = array_size; - - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - SelectionVector array_sel(STANDARD_VECTOR_SIZE); - - if (child_type_is_var_size) { - // variable size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } else { - // constant size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } - - // Pass on this array's validity mask to the child vector - RowOperations::HeapGather(child_vector, chunk_size, array_sel, array_entry_locations, &parent_validity); - - elem_remaining -= chunk_size; - array_start += chunk_size; - 
parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t *key_locations, - optional_ptr parent_validity) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = FlatVector::Validity(v); - if (parent_validity) { - for (idx_t i = 0; i < vcount; i++) { - const auto valid = parent_validity->IsValid(i); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - case PhysicalType::ARRAY: - HeapGatherArrayVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp deleted file mode 100644 index 01cf7b589..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp +++ /dev/null @@ -1,581 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -NestedValidity::NestedValidity(data_ptr_t validitymask_location) - : list_validity_location(validitymask_location), struct_validity_locations(nullptr), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { -} - -NestedValidity::NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index) - : list_validity_location(nullptr), struct_validity_locations(validitymask_locations), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { - ValidityBytes::GetEntryIndex(child_vector_index, entry_idx, idx_in_entry); -} - -void NestedValidity::SetInvalid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - 
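A loop shape that recurs throughout these gather and scatter routines is processing a variable-length child payload in fixed-size pieces: while elements remain, take min(remaining, STANDARD_VECTOR_SIZE), handle that chunk, then advance the offset. A generic sketch of that chunking, with illustrative names:

#include <algorithm>
#include <cstdint>

// Process `count` child elements starting at `offset` in chunks of at most
// `chunk_capacity` (STANDARD_VECTOR_SIZE in DuckDB), the same loop used for
// list and array children in the heap gather/scatter code.
template <class Fn>
static void ForEachChunk(uint64_t offset, uint64_t count, uint64_t chunk_capacity, Fn &&process) {
	uint64_t remaining = count;
	while (remaining > 0) {
		const uint64_t next = std::min(remaining, chunk_capacity);
		process(offset, next); // handle [offset, offset + next)
		offset += next;
		remaining -= next;
	}
}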
ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = ~(1UL << list_idx_in_entry); - list_validity_location[list_entry_idx] &= bit; - } else { - // Is Struct - const auto bit = ~(1UL << idx_in_entry); - *(struct_validity_locations[idx] + entry_idx) &= bit; - } -} - -void NestedValidity::OffsetListBy(idx_t offset) { - list_validity_offset += offset; -} - -bool NestedValidity::IsValid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = (1UL << list_idx_in_entry); - return list_validity_location[list_entry_idx] & bit; - } else { - // Is Struct - const auto bit = (1UL << idx_in_entry); - return *(struct_validity_locations[idx] + entry_idx) & bit; - } -} - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= 
next; - entry_offset += next; - } - } - } -} - -static void ComputeArrayEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_vector = ArrayVector::GetEntry(v); - - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - const idx_t array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - - // Validity for the array elements - entry_sizes[i] += array_validitymask_size; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ArrayType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += array_size * sizeof(idx_t); - } - - auto elem_idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(elem_idx + offset); - - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - // the array could span multiple vectors, so we divide it into chunks - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // compute and add to the total - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t arr_elem_idx = 0; arr_elem_idx < chunk_size; arr_elem_idx++) { - entry_sizes[i] += array_entry_sizes[arr_elem_idx]; - } - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::ARRAY: - ComputeArrayEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = 
sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - NestedValidity struct_validity(struct_validitymask_locations, i); - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, key_locations, &struct_validity, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = 
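Scattering variable-length data is a two-pass affair: ComputeEntrySizes first adds up how many heap bytes each row needs (for a VARCHAR, a 4-byte length plus the payload; lists additionally reserve an 8-byte length and one validity bit per child), and only then is heap space allocated and filled. A simplified size pass for a single string column, with NULLs represented as null pointers:

#include <cstdint>
#include <string>
#include <vector>

// First pass of a two-pass scatter: accumulate the heap bytes each row will
// need for one string column (length prefix + payload), matching the shape
// of ComputeStringEntrySizes. NULL rows take no heap space.
static void AddStringEntrySizes(const std::vector<const std::string *> &column, std::vector<uint64_t> &entry_sizes) {
	for (uint64_t i = 0; i < column.size(); i++) {
		if (column[i]) { // non-NULL
			entry_sizes[i] += sizeof(uint32_t) + column[i]->size();
		}
	}
}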
ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (parent_validity) { - // set the row validitymask for this column to invalid - parent_validity->SetInvalid(i); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, list_entry_locations, nullptr, - entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapScatterArrayVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_type = ArrayType::GetChildType(v.GetType()); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - UnifiedVectorFormat child_vdata; - 
child_vector.ToUnifiedFormat(ArrayVector::GetTotalSize(v), child_vdata); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - // Set if the whole array itself is null in the parent entry - auto source_idx = vdata.sel->get_index(sel.get_index(i) + offset); - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - - // Now we can serialize the array itself - // Every array starts with a validity mask for the children - data_ptr_t array_validitymask_location = key_locations[i]; - memset(array_validitymask_location, -1, array_validitymask_size); - key_locations[i] += array_validitymask_size; - - NestedValidity array_parent_validity(array_validitymask_location); - - // If the array contains variable size entries, we reserve spaces for them here - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // Then comes the elements - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - while (elem_remaining > 0) { - // the array elements can span multiple vectors, so we divide it into chunks - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // Setup the locations for the elements - if (child_type_is_var_size) { - // The elements are variable sized - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += array_entry_sizes[elem_idx]; - - // Now store the size of the entry - Store(array_entry_sizes[elem_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } else { - // The elements are constant sized - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - } - } - - RowOperations::HeapScatter(child_vector, ArrayVector::GetTotalSize(v), - *FlatVector::IncrementalSelectionVector(), chunk_size, array_entry_locations, - &array_parent_validity, array_start); - - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - array_parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, key_locations, - parent_validity, offset); - } else { - switch (v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case 
PhysicalType::ARRAY: - HeapScatterArrayVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp deleted file mode 100644 index a85a71997..000000000 --- a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 2, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 2); - key_location += width - 2; - } - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width, array_offset); - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - } - key_locations[i]++; - } - width--; - } - // serialize the struct - auto &child_vector = *StructVector::GetEntries(v)[0]; - RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, - key_locations, false, true, false, prefix_len, width, offset); - // invert bits if desc - if (desc) { - for (idx_t i = 0; i < add_count; i++) { - for (idx_t s = 0; s < width; s++) { - *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); - } - } - } -} - -void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, - idx_t prefix_len, idx_t width, idx_t offset) { -#ifdef DEBUG - // initialize to verify written width later - auto key_locations_copy = make_uniq_array(ser_count); - for (idx_t i = 0; i < ser_count; i++) { - key_locations_copy[i] = key_locations[i]; - } -#endif - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::FLOAT: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::DOUBLE: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INTERVAL: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::VARCHAR: - RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); - break; - case PhysicalType::LIST: - RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, - offset); 
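// [Illustrative aside, hypothetical sketch; not taken from the patch] The radix scatter
// above builds, for every row, a fixed-width byte key of the form
// [validity byte][payload bytes] so rows can be ordered with a plain memcmp. For DESC
// columns all payload bits are inverted, while the validity byte is left alone because
// NULLS FIRST/LAST is handled separately. A minimal sketch for an unsigned 32-bit
// column, with made-up names:
//
//     static void EncodeUint32Key(uint32_t value, bool is_valid, bool nulls_first,
//                                 bool desc, uint8_t key[5]) {
//         const uint8_t valid_byte = nulls_first ? 1 : 0;
//         key[0] = is_valid ? valid_byte : uint8_t(1 - valid_byte);
//         if (!is_valid) {
//             memset(key + 1, 0, 4); // NULLs carry no payload bytes
//             return;
//         }
//         for (int i = 0; i < 4; i++) { // big-endian so byte order matches numeric order
//             key[1 + i] = uint8_t(value >> (8 * (3 - i)));
//         }
//         if (desc) {
//             for (int i = 1; i < 5; i++) {
//                 key[i] = uint8_t(~key[i]); // invert payload bits for descending order
//             }
//         }
//     }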
- break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - case PhysicalType::ARRAY: - RadixScatterArrayVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } - -#ifdef DEBUG - for (idx_t i = 0; i < ser_count; i++) { - D_ASSERT(key_locations[i] == key_locations_copy[i] + width); - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_scatter.cpp b/src/duckdb/src/common/row_operations/row_scatter.cpp deleted file mode 100644 index 1912d2484..000000000 --- a/src/duckdb/src/common/row_operations/row_scatter.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, - const idx_t col_offset, const idx_t col_no, const idx_t col_count) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx], col_count); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t col_count) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. 
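// [Illustrative aside, hypothetical sketch; not taken from the patch] "Swizzling" in the
// row layout means replacing an absolute heap pointer stored in a row with an offset
// relative to the start of its heap block, so that spilled row data can be reloaded at a
// different address; writing an explicit zero-length string_t for NULL entries keeps
// that conversion well defined. The pointer/offset round trip, with made-up names:
//
//     inline uint64_t SwizzleToOffset(const uint8_t *ptr, const uint8_t *heap_base) {
//         return static_cast<uint64_t>(ptr - heap_base); // store an offset, not a pointer
//     }
//     inline uint8_t *UnswizzleToPointer(uint64_t offset, uint8_t *heap_base) {
//         return heap_base + offset; // restore the pointer once the block is pinned again
//     }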
- const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row, col_count); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), UnsafeNumericCast(str.GetSize())); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - NestedValidity parent_validity(validitymask_locations, col_no); - RowOperations::HeapScatter(vec, vcount, sel, count, data_locations, &parent_validity); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - idx_t column_count = layout.ColumnCount(); - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row, column_count).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row 
size is stored in the heap in front of each row - Store(NumericCast(entry_sizes[i]), data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/memory_stream.cpp b/src/duckdb/src/common/serializer/memory_stream.cpp index eb4fb6d57..277b94460 100644 --- a/src/duckdb/src/common/serializer/memory_stream.cpp +++ b/src/duckdb/src/common/serializer/memory_stream.cpp @@ -59,6 +59,12 @@ MemoryStream &MemoryStream::operator=(MemoryStream &&other) noexcept { } void MemoryStream::WriteData(const_data_ptr_t source, idx_t write_size) { + GrowCapacity(write_size); + memcpy(data + position, source, write_size); + position += write_size; +} + +void MemoryStream::GrowCapacity(idx_t write_size) { const auto old_capacity = capacity; while (position + write_size > capacity) { if (allocator) { @@ -70,8 +76,6 @@ void MemoryStream::WriteData(const_data_ptr_t source, idx_t write_size) { if (capacity != old_capacity) { data = allocator->ReallocateData(data, old_capacity, capacity); } - memcpy(data + position, source, write_size); - position += write_size; } void MemoryStream::ReadData(data_ptr_t destination, idx_t read_size) { diff --git a/src/duckdb/src/common/sort/comparators.cpp b/src/duckdb/src/common/sort/comparators.cpp deleted file mode 100644 index 4df4cccc4..000000000 --- a/src/duckdb/src/common/sort/comparators.cpp +++ 
/dev/null @@ -1,507 +0,0 @@ -#include "duckdb/common/sort/comparators.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr, sort_layout.column_count); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col] && tie_string.GetSize() > 0) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::ARRAY: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout) && !TieIsBreakable(tie_col, r_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - case PhysicalType::ARRAY: - return CompareArrayAndAdvance(l_ptr, r_ptr, ArrayType::GetChildType(type), valid, ArrayType::GetSize(type)); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; 
- right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes left_validity(left_ptr, types.size()); - ValidityBytes right_validity(right_ptr, types.size()); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid, idx_t array_size) { - if (!valid) { - return 0; - } - - // Load array validity masks - ValidityBytes left_validity(left_ptr, array_size); - ValidityBytes right_validity(right_ptr, array_size); - left_ptr += (array_size + 7) / 8; - right_ptr += (array_size + 7) / 8; - - int comp_res = 0; - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT8: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT128: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, 
array_size); - break; - case PhysicalType::INTERVAL: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized array entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += array_size * sizeof(idx_t); - right_ptr += array_size * sizeof(idx_t); - for (idx_t i = 0; i < array_size; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareArrayAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr, left_len); - ValidityBytes right_validity(right_ptr, right_len); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, 
right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { - int comp_res = 0; - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res 
= -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/full_sort.cpp b/src/duckdb/src/common/sort/full_sort.cpp new file mode 100644 index 000000000..1fde6661b --- /dev/null +++ b/src/duckdb/src/common/sort/full_sort.cpp @@ -0,0 +1,368 @@ +#include "duckdb/common/sorting/full_sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// FullSortGroup +//===--------------------------------------------------------------------===// +class FullSortGroup { +public: + FullSortGroup(ClientContext &client, const Sort &sort); + + atomic count; + + // Sink + unique_ptr sort_global; + + // Source + unique_ptr sort_source; +}; + +FullSortGroup::FullSortGroup(ClientContext &client, const Sort &sort) : count(0) { + sort_global = sort.GetGlobalSinkState(client); +} + +//===--------------------------------------------------------------------===// +// FullSortGlobalSinkState +//===--------------------------------------------------------------------===// +class FullSortGlobalSinkState : public GlobalSinkState { +public: + using HashGroupPtr = unique_ptr; + + FullSortGlobalSinkState(ClientContext &client, const FullSort &full_sort); + + // OVER(PARTITION BY...) (hash grouping) + ProgressData GetSinkProgress(ClientContext &context, const ProgressData source_progress) const; + + //! System and query state + ClientContext &client; + const FullSort &full_sort; + mutable mutex lock; + + // OVER(...) (sorting) + HashGroupPtr hash_group; + + // Threading + atomic count; +}; + +FullSortGlobalSinkState::FullSortGlobalSinkState(ClientContext &client, const FullSort &full_sort) + : client(client), full_sort(full_sort), count(0) { + // Sort early into a dedicated hash group if we only sort. + hash_group = make_uniq(client, *full_sort.sort); +} + +ProgressData FullSortGlobalSinkState::GetSinkProgress(ClientContext &client, const ProgressData source) const { + ProgressData result; + result.done = source.done / 2; + result.total = source.total; + result.invalid = source.invalid; + + // Sort::GetSinkProgress assumes that there is only 1 sort. + // So we just use it to figure out how many rows have been sorted. + const ProgressData zero_progress; + lock_guard guard(lock); + const auto &sort = full_sort.sort; + + const auto group_progress = sort->GetSinkProgress(client, *hash_group->sort_global, zero_progress); + result.done += group_progress.done; + result.invalid = result.invalid || group_progress.invalid; + + return result; +} + +SinkFinalizeType FullSort::Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const { + auto &gsink = finalize.global_state.Cast(); + + // Did we get any data? + if (!gsink.count) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + + // OVER(ORDER BY...) 
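// [Illustrative aside; not taken from the patch] The new FullSort strategy follows the
// usual DuckDB sink lifecycle visible in this file: Sink() is called per thread with
// payload chunks, Combine() folds each thread-local state into the global sink state,
// and Finalize() runs once to prepare the global source. A hypothetical single-threaded
// driver, using only the interfaces declared in this diff:
//
//     auto gsink = full_sort.GetGlobalSinkState(client);
//     auto lsink = full_sort.GetLocalSinkState(context);
//     OperatorSinkInput sink_input {*gsink, *lsink, interrupt_state};
//     while (NextChunk(chunk)) {                       // NextChunk is a stand-in
//         full_sort.Sink(context, chunk, sink_input);
//     }
//     OperatorSinkCombineInput combine_input {*gsink, *lsink, interrupt_state};
//     full_sort.Combine(context, combine_input);
//     OperatorSinkFinalizeInput finalize_input {*gsink, interrupt_state};
//     full_sort.Finalize(client, finalize_input);
//     auto gsource = full_sort.GetGlobalSourceState(client, *gsink);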
+ auto &hash_group = gsink.hash_group; + auto &global_sink = *hash_group->sort_global; + OperatorSinkFinalizeInput hfinalize {global_sink, finalize.interrupt_state}; + sort->Finalize(client, hfinalize); + hash_group->sort_source = sort->GetGlobalSourceState(client, global_sink); + + return SinkFinalizeType::READY; +} + +ProgressData FullSort::GetSinkProgress(ClientContext &client, GlobalSinkState &gstate, + const ProgressData source) const { + auto &gsink = gstate.Cast(); + return gsink.GetSinkProgress(client, source); +} + +//===--------------------------------------------------------------------===// +// FullSortLocalSinkState +//===--------------------------------------------------------------------===// +// Formerly PartitionLocalSinkState +class FullSortLocalSinkState : public LocalSinkState { +public: + using LocalSortStatePtr = unique_ptr; + + FullSortLocalSinkState(ExecutionContext &context, const FullSort &full_sort); + + //! Global state + const FullSort &full_sort; + + //! Shared expression evaluation + ExpressionExecutor sort_exec; + DataChunk group_chunk; + DataChunk sort_chunk; + DataChunk payload_chunk; + + //! Compute the hash values + void Hash(DataChunk &input_chunk, Vector &hash_vector); + //! Merge the state into the global state. + void Combine(ExecutionContext &context); + + // OVER(ORDER BY...) (only sorting) + LocalSortStatePtr sort_local; + + // OVER() (no sorting) + unique_ptr unsorted; + ColumnDataAppendState unsorted_append; +}; + +FullSortLocalSinkState::FullSortLocalSinkState(ExecutionContext &context, const FullSort &full_sort) + : full_sort(full_sort), sort_exec(context.client) { + vector sort_types; + for (const auto &expr : full_sort.sort_exprs) { + sort_types.emplace_back(expr->return_type); + sort_exec.AddExpression(*expr); + } + sort_chunk.Initialize(context.client, sort_types); + + auto payload_types = full_sort.payload_types; + + // OVER(ORDER BY...) + auto &sort = *full_sort.sort; + sort_local = sort.GetLocalSinkState(context); + payload_chunk.Initialize(context.client, payload_types); +} + +SinkResultType FullSort::Sink(ExecutionContext &context, DataChunk &input_chunk, OperatorSinkInput &sink) const { + auto &gstate = sink.global_state.Cast(); + auto &lstate = sink.local_state.Cast(); + gstate.count += input_chunk.size(); + + // Payload prefix is the input data + auto &payload_chunk = lstate.payload_chunk; + payload_chunk.Reset(); + for (column_t i = 0; i < input_chunk.ColumnCount(); ++i) { + payload_chunk.data[i].Reference(input_chunk.data[i]); + } + + // Compute any sort columns that are not references and append them to the end of the payload + auto &sort_chunk = lstate.sort_chunk; + auto &sort_exec = lstate.sort_exec; + if (!sort_exprs.empty()) { + sort_chunk.Reset(); + sort_exec.Execute(input_chunk, sort_chunk); + for (column_t i = 0; i < sort_chunk.ColumnCount(); ++i) { + payload_chunk.data[input_chunk.ColumnCount() + i].Reference(sort_chunk.data[i]); + } + } + + // Append a forced payload column + if (force_payload) { + auto &vec = payload_chunk.data[input_chunk.ColumnCount() + sort_chunk.ColumnCount()]; + D_ASSERT(vec.GetType().id() == LogicalTypeId::BOOLEAN); + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + + payload_chunk.SetCardinality(input_chunk); + + // OVER(ORDER BY...) 
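// [Illustrative aside; not taken from the patch] force_payload appears to cover the case
// where every payload column doubles as a sort key: the caller requested a payload-only
// column (require_payload), none would be left over, so the constructor appends a dummy
// BOOLEAN column and Sink() fills it with a constant NULL vector, roughly:
//
//     unordered_set<idx_t> sort_set(sort_ids.begin(), sort_ids.end());
//     bool force_payload = sort_set.size() >= payload_types.size();
//     if (force_payload) {
//         payload_types.emplace_back(LogicalType::BOOLEAN); // dummy payload-only column
//     }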
+ auto &sort_local = lstate.sort_local; + D_ASSERT(sort_local); + auto &hash_group = *gstate.hash_group; + OperatorSinkInput input {*hash_group.sort_global, *sort_local, sink.interrupt_state}; + sort->Sink(context, payload_chunk, input); + hash_group.count += payload_chunk.size(); + + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType FullSort::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { + auto &gstate = combine.global_state.Cast(); + auto &lstate = combine.local_state.Cast(); + + // OVER(ORDER BY...) + D_ASSERT(lstate.sort_local); + auto &hash_group = *gstate.hash_group; + OperatorSinkCombineInput input {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; + sort->Combine(context, input); + lstate.sort_local.reset(); + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// FullSortGlobalSourceState +//===--------------------------------------------------------------------===// +class FullSortGlobalSourceState : public GlobalSourceState { +public: + using ChunkRow = FullSort::ChunkRow; + using ChunkRows = FullSort::ChunkRows; + + FullSortGlobalSourceState(ClientContext &client, FullSortGlobalSinkState &gsink); + + FullSortGlobalSinkState &gsink; + ChunkRows chunk_rows; +}; + +FullSortGlobalSourceState::FullSortGlobalSourceState(ClientContext &client, FullSortGlobalSinkState &gsink) + : gsink(gsink) { + if (!gsink.count) { + return; + } + + // One sorted group + ChunkRow chunk_row; + + auto &hash_group = gsink.hash_group; + if (hash_group) { + chunk_row.count = hash_group->count; + chunk_row.chunks = (chunk_row.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE; + } + + chunk_rows.emplace_back(chunk_row); +} + +//===--------------------------------------------------------------------===// +// FullSort +//===--------------------------------------------------------------------===// +FullSort::FullSort(ClientContext &client, const vector &order_bys, const Types &input_types, + bool require_payload) + : SortStrategy(input_types) { + // We have to compute ordering expressions ourselves and materialise them. + // To do this, we scan the orders and add generate extra payload columns that we can reference. + for (const auto &order : order_bys) { + orders.emplace_back(order.Copy()); + } + + for (auto &order : orders) { + auto &expr = *order.expression; + if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { + auto &ref = expr.Cast(); + sort_ids.emplace_back(ref.index); + continue; + } + + // Real expression - replace with a ref and save the expression + auto saved = std::move(order.expression); + const auto type = saved->return_type; + const auto idx = payload_types.size(); + order.expression = make_uniq(type, idx); + sort_ids.emplace_back(idx); + payload_types.emplace_back(type); + sort_exprs.emplace_back(std::move(saved)); + } + + // If a payload column is required, check whether there is one already + if (require_payload) { + // Watch out for duplicate sort keys! 
+ unordered_set sort_set(sort_ids.begin(), sort_ids.end()); + force_payload = (sort_set.size() >= payload_types.size()); + if (force_payload) { + payload_types.emplace_back(LogicalType::BOOLEAN); + } + } + vector projection_map; + sort = make_uniq(client, orders, payload_types, projection_map); +} + +unique_ptr FullSort::GetGlobalSinkState(ClientContext &client) const { + return make_uniq(client, *this); +} + +unique_ptr FullSort::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr FullSort::GetGlobalSourceState(ClientContext &client, GlobalSinkState &sink) const { + return make_uniq(client, sink.Cast()); +} + +const FullSort::ChunkRows &FullSort::GetHashGroups(GlobalSourceState &gstate) const { + auto &gsource = gstate.Cast(); + return gsource.chunk_rows; +} + +SourceResultType FullSort::MaterializeSortedData(ExecutionContext &context, bool build_runs, + OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_group; + + auto &sort_global = *hash_group.sort_source; + auto sort_local = sort->GetLocalSourceState(context, sort_global); + + OperatorSourceInput input {sort_global, *sort_local, source.interrupt_state}; + if (build_runs) { + return sort->MaterializeSortedRun(context, input); + } else { + return sort->MaterializeColumnData(context, input); + } +} + +SourceResultType FullSort::MaterializeColumnData(ExecutionContext &execution, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeSortedData(execution, false, source); +} + +FullSort::HashGroupPtr FullSort::GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_group; + + D_ASSERT(hash_bin == 0); + + auto &sort_global = *hash_group.sort_source; + + OperatorSourceInput input {sort_global, source.local_state, source.interrupt_state}; + auto result = sort->GetColumnData(input); + hash_group.sort_source.reset(); + + // Just because MaterializeColumnData returned FINISHED doesn't mean that the same thread will + // get the result... 
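// [Editorial clarification, not taken from the patch] Several threads may help
// materialize the sorted data; only the thread that sees the fully materialized row
// count below takes ownership of the result, and every other caller gets nullptr, so
// exactly one thread hands the column data onward.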
+ if (result && result->Count() == hash_group.count) { + return result; + } + + return nullptr; +} + +SourceResultType FullSort::MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeSortedData(context, true, source); +} + +FullSort::SortedRunPtr FullSort::GetSortedRun(ClientContext &client, idx_t hash_bin, + OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_group; + + D_ASSERT(hash_bin == 0); + + auto &sort_global = *hash_group.sort_source; + + auto result = sort->GetSortedRun(sort_global); + if (!result) { + D_ASSERT(hash_group.count == 0); + result = make_uniq(client, *sort, false); + } + + hash_group.sort_source.reset(); + + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/hashed_sort.cpp b/src/duckdb/src/common/sort/hashed_sort.cpp similarity index 60% rename from src/duckdb/src/common/sorting/hashed_sort.cpp rename to src/duckdb/src/common/sort/hashed_sort.cpp index 5571e0bc3..c84fb4723 100644 --- a/src/duckdb/src/common/sorting/hashed_sort.cpp +++ b/src/duckdb/src/common/sort/hashed_sort.cpp @@ -1,6 +1,6 @@ #include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" @@ -9,37 +9,35 @@ namespace duckdb { //===--------------------------------------------------------------------===// // HashedSortGroup //===--------------------------------------------------------------------===// -// Formerly PartitionGlobalHashGroup class HashedSortGroup { public: using Orders = vector; using Types = vector; - HashedSortGroup(ClientContext &client, optional_ptr sort, idx_t group_idx); + HashedSortGroup(ClientContext &client, Sort &sort, idx_t group_idx); const idx_t group_idx; + atomic count; // Sink - optional_ptr sort; + Sort &sort; unique_ptr sort_global; // Source + mutex scan_lock; + TupleDataParallelScanState parallel_scan; atomic tasks_completed; unique_ptr sort_source; - unique_ptr sorted; }; -HashedSortGroup::HashedSortGroup(ClientContext &client, optional_ptr sort, idx_t group_idx) - : group_idx(group_idx), sort(sort), tasks_completed(0) { - if (sort) { - sort_global = sort->GetGlobalSinkState(client); - } +HashedSortGroup::HashedSortGroup(ClientContext &client, Sort &sort, idx_t group_idx) + : group_idx(group_idx), count(0), sort(sort), tasks_completed(0) { + sort_global = sort.GetGlobalSinkState(client); } //===--------------------------------------------------------------------===// // HashedSortGlobalSinkState //===--------------------------------------------------------------------===// -// Formerly PartitionGlobalSinkState class HashedSortGlobalSinkState : public GlobalSinkState { public: using HashGroupPtr = unique_ptr; @@ -53,11 +51,13 @@ class HashedSortGlobalSinkState : public GlobalSinkState { // OVER(PARTITION BY...) (hash grouping) unique_ptr CreatePartition(idx_t new_bits) const; + void SyncPartitioning(const HashedSortGlobalSinkState &other); void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &partition_append); void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); ProgressData GetSinkProgress(ClientContext &context, const ProgressData source_progress) const; //! 
System and query state + ClientContext &client; const HashedSort &hashed_sort; BufferManager &buffer_manager; Allocator &allocator; @@ -83,9 +83,8 @@ class HashedSortGlobalSinkState : public GlobalSinkState { }; HashedSortGlobalSinkState::HashedSortGlobalSinkState(ClientContext &client, const HashedSort &hashed_sort) - : hashed_sort(hashed_sort), buffer_manager(BufferManager::GetBufferManager(client)), + : client(client), hashed_sort(hashed_sort), buffer_manager(BufferManager::GetBufferManager(client)), allocator(Allocator::Get(client)), fixed_bits(0), max_bits(1), count(0) { - const auto memory_per_thread = PhysicalOperator::GetMaxThreadMemory(client); const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); while (max_bits < 8 && (thread_pages >> max_bits) > 1) { @@ -93,28 +92,18 @@ HashedSortGlobalSinkState::HashedSortGlobalSinkState(ClientContext &client, cons } grouping_types_ptr = make_shared_ptr(); - auto &partitions = hashed_sort.partitions; - auto &orders = hashed_sort.orders; auto &payload_types = hashed_sort.payload_types; - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types_ptr->Initialize(payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - auto new_group = make_uniq(hashed_sort.client, *hashed_sort.sort, idx_t(0)); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - Rehash(hashed_sort.estimated_cardinality); - } - } + auto types = payload_types; + types.push_back(LogicalType::HASH); + grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); + Rehash(hashed_sort.estimated_cardinality); } unique_ptr HashedSortGlobalSinkState::CreatePartition(idx_t new_bits) const { auto &payload_types = hashed_sort.payload_types; const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types_ptr, new_bits, hash_col_idx); + return make_uniq(buffer_manager, grouping_types_ptr, MemoryTag::WINDOW, new_bits, + hash_col_idx); } void HashedSortGlobalSinkState::Rehash(idx_t cardinality) { @@ -146,7 +135,7 @@ void HashedSortGlobalSinkState::SyncLocalPartition(GroupingPartition &local_part // If the local partition is now too small, flush it and reallocate auto new_partition = CreatePartition(new_bits); local_partition->FlushAppendState(*local_append); - local_partition->Repartition(hashed_sort.client, *new_partition); + local_partition->Repartition(client, *new_partition); local_partition = std::move(new_partition); local_append = make_uniq(); @@ -172,6 +161,15 @@ void HashedSortGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_pa SyncLocalPartition(local_partition, partition_append); } +void HashedSortGlobalSinkState::SyncPartitioning(const HashedSortGlobalSinkState &other) { + fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; + + const auto old_bits = grouping_data ? 
grouping_data->GetRadixBits() : 0; + if (fixed_bits != old_bits) { + grouping_data = CreatePartition(fixed_bits); + } +} + void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { if (!local_partition) { @@ -200,9 +198,12 @@ void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_p auto &group_data = groups[group_idx]; if (group_data->Count()) { - hash_group = make_uniq(hashed_sort.client, *hashed_sort.sort, group_idx); + hash_group = make_uniq(client, *hashed_sort.sort, group_idx); } } + + // Combine the thread data into the global data + grouping_data->Combine(*local_partition); } ProgressData HashedSortGlobalSinkState::GetSinkProgress(ClientContext &client, const ProgressData source) const { @@ -237,19 +238,23 @@ SinkFinalizeType HashedSort::Finalize(ClientContext &client, OperatorSinkFinaliz return SinkFinalizeType::NO_OUTPUT_POSSIBLE; } - // OVER() - if (!sort) { - return SinkFinalizeType::READY; - } - - // OVER(...) + // OVER(PARTITION BY...) + auto &partitions = gsink.grouping_data->GetPartitions(); D_ASSERT(!gsink.hash_groups.empty()); - for (auto &hash_group : gsink.hash_groups) { + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &partition = *partitions[hash_bin]; + if (!partition.Count()) { + continue; + } + + auto &hash_group = gsink.hash_groups[hash_bin]; if (!hash_group) { continue; } - OperatorSinkFinalizeInput hfinalize {*hash_group->sort_global, finalize.interrupt_state}; - sort->Finalize(client, hfinalize); + + // Prepare to scan into the sort + auto ¶llel_scan = hash_group->parallel_scan; + partition.InitializeScan(parallel_scan, partition_ids); } return SinkFinalizeType::READY; @@ -264,7 +269,6 @@ ProgressData HashedSort::GetSinkProgress(ClientContext &client, GlobalSinkState //===--------------------------------------------------------------------===// // HashedSortLocalSinkState //===--------------------------------------------------------------------===// -// Formerly PartitionLocalSinkState class HashedSortLocalSinkState : public LocalSinkState { public: using LocalSortStatePtr = unique_ptr; @@ -292,20 +296,11 @@ class HashedSortLocalSinkState : public LocalSinkState { // OVER(PARTITION BY...) (hash grouping) GroupingPartition local_grouping; GroupingAppend grouping_append; - - // OVER(ORDER BY...) (only sorting) - LocalSortStatePtr sort_local; - InterruptState interrupt; - - // OVER() (no sorting) - unique_ptr unsorted; - ColumnDataAppendState unsorted_append; }; HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, const HashedSort &hashed_sort) : hashed_sort(hashed_sort), allocator(Allocator::Get(context.client)), hash_exec(context.client), sort_exec(context.client) { - vector group_types; for (idx_t prt_idx = 0; prt_idx < hashed_sort.partitions.size(); prt_idx++) { auto &pexpr = *hashed_sort.partitions[prt_idx].expression.get(); @@ -320,31 +315,16 @@ HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, co } sort_chunk.Initialize(context.client, sort_types); - if (hashed_sort.sort_col_count) { - auto payload_types = hashed_sort.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) 
- for (idx_t ord_idx = 0; ord_idx < hashed_sort.orders.size(); ord_idx++) { - auto &pexpr = *hashed_sort.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - hash_exec.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &sort = *hashed_sort.sort; - sort_local = sort.GetLocalSinkState(context); - } - // OVER(...) - payload_chunk.Initialize(allocator, payload_types); - } else { - unsorted = make_uniq(context.client, hashed_sort.payload_types); - unsorted->InitializeAppend(unsorted_append); - } + auto hash_types = hashed_sort.payload_types; + group_chunk.Initialize(allocator, group_types); + hash_types.emplace_back(LogicalType::HASH); + payload_chunk.Initialize(allocator, hash_types); +} + +void HashedSort::Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const { + auto &src = source.Cast(); + auto &tgt = target.Cast(); + tgt.SyncPartitioning(src); } void HashedSortLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { @@ -365,17 +345,6 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun auto &lstate = sink.local_state.Cast(); gstate.count += input_chunk.size(); - // Window::Sink: - // PartitionedTupleData::Append - // Sort::Sink - // ColumnDataCollection::Append - - // OVER() - if (gstate.hashed_sort.sort_col_count == 0) { - lstate.unsorted->Append(lstate.unsorted_append, input_chunk); - return SinkResultType::NEED_MORE_INPUT; - } - // Payload prefix is the input data auto &payload_chunk = lstate.payload_chunk; payload_chunk.Reset(); @@ -393,17 +362,17 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun payload_chunk.data[input_chunk.ColumnCount() + i].Reference(sort_chunk.data[i]); } } - payload_chunk.SetCardinality(input_chunk); - // OVER(ORDER BY...) - auto &sort_local = lstate.sort_local; - if (sort_local) { - auto &hash_group = *gstate.hash_groups[0]; - OperatorSinkInput input {*hash_group.sort_global, *sort_local, sink.interrupt_state}; - sort->Sink(context, payload_chunk, input); - return SinkResultType::NEED_MORE_INPUT; + // Append a forced payload column + if (force_payload) { + auto &vec = payload_chunk.data[input_chunk.ColumnCount() + sort_chunk.ColumnCount()]; + D_ASSERT(vec.GetType().id() == LogicalTypeId::BOOLEAN); + vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); } + payload_chunk.SetCardinality(input_chunk); + // OVER(PARTITION BY...) auto &hash_vector = payload_chunk.data.back(); lstate.Hash(input_chunk, hash_vector); @@ -423,40 +392,6 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin auto &gstate = combine.global_state.Cast(); auto &lstate = combine.local_state.Cast(); - // Window::Combine: - // Sort::Sink then Sort::Combine (per hash partition) - // Sort::Combine - // ColumnDataCollection::Combine - - // OVER() - if (gstate.hashed_sort.sort_col_count == 0) { - // Only one partition again, so need a global lock. 
- lock_guard glock(gstate.lock); - auto &hash_groups = gstate.hash_groups; - if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - auto &unsorted = *hash_groups[0]->sorted; - if (lstate.unsorted) { - unsorted.Combine(*lstate.unsorted); - lstate.unsorted.reset(); - } - } else { - auto new_group = make_uniq(context.client, sort, idx_t(0)); - new_group->sorted = std::move(lstate.unsorted); - hash_groups.emplace_back(std::move(new_group)); - } - return SinkCombineResultType::FINISHED; - } - - // OVER(ORDER BY...) - if (lstate.sort_local) { - auto &hash_group = *gstate.hash_groups[0]; - OperatorSinkCombineInput input {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; - sort->Combine(context, input); - lstate.sort_local.reset(); - return SinkCombineResultType::FINISHED; - } - // OVER(PARTITION BY...) auto &local_grouping = lstate.local_grouping; if (!local_grouping) { @@ -467,152 +402,85 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin auto &grouping_append = lstate.grouping_append; gstate.CombineLocalPartition(local_grouping, grouping_append); - // Don't scan the hash column - vector column_ids; - for (column_t i = 0; i < payload_types.size(); ++i) { - column_ids.emplace_back(i); - } + return SinkCombineResultType::FINISHED; +} + +void HashedSort::SortColumnData(ExecutionContext &context, hash_t hash_bin, OperatorSinkFinalizeInput &finalize) { + auto &gstate = finalize.global_state.Cast(); // Loop over the partitions and add them to each hash group's global sort state - TupleDataScanState scan_state; - DataChunk chunk; - auto &partitions = local_grouping->GetPartitions(); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &partitions = gstate.grouping_data->GetPartitions(); + if (hash_bin < partitions.size()) { auto &partition = *partitions[hash_bin]; if (!partition.Count()) { - continue; - } - - partition.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE); - if (chunk.data.empty()) { - partition.InitializeScanChunk(scan_state, chunk); + return; } auto &hash_group = *gstate.hash_groups[hash_bin]; - lstate.sort_local = sort->GetLocalSinkState(context); - OperatorSinkInput sink {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; - while (partition.Scan(scan_state, chunk)) { - sort->Sink(context, chunk, sink); - } + auto ¶llel_scan = hash_group.parallel_scan; - OperatorSinkCombineInput lcombine {*hash_group.sort_global, *lstate.sort_local, combine.interrupt_state}; - sort->Combine(context, lcombine); - } + DataChunk chunk; + partition.InitializeScanChunk(parallel_scan.scan_state, chunk); + TupleDataLocalScanState local_scan; + partition.InitializeScan(local_scan); - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// HashedSortMaterializeTask -//===--------------------------------------------------------------------===// -class HashedSortMaterializeTask : public ExecutorTask { -public: - HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, const PhysicalOperator &op, - HashedSortGroup &hash_group, idx_t tasks_scheduled); - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - - string TaskType() const override { - return "HashedSortMaterializeTask"; - } - -private: - Pipeline &pipeline; - HashedSortGroup &hash_group; - const idx_t tasks_scheduled; -}; + auto sort_local = sort->GetLocalSinkState(context); + OperatorSinkInput sink 
{*hash_group.sort_global, *sort_local, finalize.interrupt_state}; + idx_t combined = 0; + while (partition.Scan(hash_group.parallel_scan, local_scan, chunk)) { + sort->Sink(context, chunk, sink); + combined += chunk.size(); + } -HashedSortMaterializeTask::HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, - const PhysicalOperator &op, HashedSortGroup &hash_group, - idx_t tasks_scheduled) - : ExecutorTask(pipeline.GetClientContext(), std::move(event), op), pipeline(pipeline), hash_group(hash_group), - tasks_scheduled(tasks_scheduled) { -} + OperatorSinkCombineInput combine {*hash_group.sort_global, *sort_local, finalize.interrupt_state}; + sort->Combine(context, combine); + hash_group.count += combined; -TaskExecutionResult HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); - auto &sort = *hash_group.sort; - auto &sort_global = *hash_group.sort_source; - auto sort_local = sort.GetLocalSourceState(execution, sort_global); - InterruptState interrupt((weak_ptr(shared_from_this()))); - OperatorSourceInput input {sort_global, *sort_local, interrupt}; - sort.MaterializeColumnData(execution, input); - if (++hash_group.tasks_completed == tasks_scheduled) { - hash_group.sorted = sort.GetColumnData(input); + // Whoever finishes last can Finalize + lock_guard finalize_guard(hash_group.scan_lock); + if (hash_group.count == partition.Count() && !hash_group.sort_source) { + OperatorSinkFinalizeInput lfinalize {*hash_group.sort_global, finalize.interrupt_state}; + sort->Finalize(context.client, lfinalize); + hash_group.sort_source = sort->GetGlobalSourceState(context.client, *hash_group.sort_global); + } } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; } //===--------------------------------------------------------------------===// -// HashedSortMaterializeEvent +// HashedSortGlobalSourceState //===--------------------------------------------------------------------===// -// Formerly PartitionMergeEvent -class HashedSortMaterializeEvent : public BasePipelineEvent { +class HashedSortGlobalSourceState : public GlobalSourceState { public: - HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, const PhysicalOperator &op); + using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; + using ChunkRow = HashedSort::ChunkRow; + using ChunkRows = HashedSort::ChunkRows; - HashedSortGlobalSinkState &gstate; - const PhysicalOperator &op; + HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink); -public: - void Schedule() override; + HashedSortGlobalSinkState &gsink; + ChunkRows chunk_rows; }; -HashedSortMaterializeEvent::HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, - const PhysicalOperator &op) - : BasePipelineEvent(pipeline), gstate(gstate), op(op) { -} - -void HashedSortMaterializeEvent::Schedule() { - auto &client = pipeline->GetClientContext(); - - // Schedule as many tasks per hash group as the sort will allow - auto &ts = TaskScheduler::GetScheduler(client); - const auto num_threads = NumericCast(ts.NumberOfThreads()); - auto &sort = *gstate.hashed_sort.sort; - - vector> merge_tasks; - for (auto &hash_group : gstate.hash_groups) { - if (!hash_group) { - continue; - } - auto &global_sink = *hash_group->sort_global; - hash_group->sort_source = sort.GetGlobalSourceState(client, global_sink); - const auto tasks_scheduled = MinValue(num_threads, 
hash_group->sort_source->MaxThreads()); - for (idx_t t = 0; t < tasks_scheduled; ++t) { - merge_tasks.emplace_back( - make_uniq(*pipeline, shared_from_this(), op, *hash_group, tasks_scheduled)); - } +HashedSortGlobalSourceState::HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink) + : gsink(gsink) { + if (!gsink.count) { + return; } - SetTasks(std::move(merge_tasks)); -} - -//===--------------------------------------------------------------------===// -// HashedSortGlobalSourceState -//===--------------------------------------------------------------------===// -class HashedSortGlobalSourceState : public GlobalSourceState { -public: - using HashGroupPtr = unique_ptr; + auto &partitions = gsink.grouping_data->GetPartitions(); + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + ChunkRow chunk_row; - HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink) { - if (!gsink.count) { - return; - } - hash_groups.resize(gsink.hash_groups.size()); - for (auto &hash_group : gsink.hash_groups) { - if (!hash_group) { - continue; - } - const auto group_idx = hash_group->group_idx; - hash_groups[group_idx] = std::move(hash_group->sorted); + auto &hash_group = gsink.hash_groups[hash_bin]; + if (hash_group) { + chunk_row.count = partitions[hash_bin]->Count(); + chunk_row.chunks = (chunk_row.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE; } - } - vector hash_groups; -}; + chunk_rows.emplace_back(chunk_row); + } +} //===--------------------------------------------------------------------===// // HashedSort @@ -620,7 +488,6 @@ class HashedSortGlobalSourceState : public GlobalSourceState { void HashedSort::GenerateOrderings(Orders &partitions, Orders &orders, const vector> &partition_bys, const Orders &order_bys, const vector> &partition_stats) { - // we sort by both 1) partition by expression list and 2) order by expressions const auto partition_cols = partition_bys.size(); for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { @@ -642,15 +509,11 @@ void HashedSort::GenerateOrderings(Orders &partitions, Orders &orders, HashedSort::HashedSort(ClientContext &client, const vector> &partition_bys, const vector &order_bys, const Types &input_types, - const vector> &partition_stats, idx_t estimated_cardinality) - : client(client), estimated_cardinality(estimated_cardinality), payload_types(input_types) { + const vector> &partition_stats, idx_t estimated_cardinality, + bool require_payload) + : SortStrategy(input_types), estimated_cardinality(estimated_cardinality) { GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); - // The payload prefix is the same as the input schema - for (column_t i = 0; i < payload_types.size(); ++i) { - scan_ids.emplace_back(i); - } - // We have to compute ordering expressions ourselves and materialise them. // To do this, we scan the orders and add generate extra payload columns that we can reference. for (auto &order : orders) { @@ -671,11 +534,23 @@ HashedSort::HashedSort(ClientContext &client, const vector projection_map; - sort = make_uniq(client, orders, payload_types, projection_map); + // If a payload column is required, check whether there is one already + if (require_payload) { + // Watch out for duplicate sort keys! 
+ unordered_set sort_set(sort_ids.begin(), sort_ids.end()); + force_payload = (sort_set.size() >= payload_types.size()); + if (force_payload) { + payload_types.emplace_back(LogicalType::BOOLEAN); + } } + + // Remember the full set of materialised partition columns + for (column_t i = 0; i < payload_types.size(); ++i) { + partition_ids.emplace_back(i); + } + + vector projection_map; + sort = make_uniq(client, orders, payload_types, projection_map); } unique_ptr HashedSort::GetGlobalSinkState(ClientContext &client) const { @@ -686,39 +561,92 @@ unique_ptr HashedSort::GetLocalSinkState(ExecutionContext &conte return make_uniq(context, *this); } -unique_ptr HashedSort::GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const { +unique_ptr HashedSort::GetGlobalSourceState(ClientContext &client, GlobalSinkState &sink) const { return make_uniq(client, sink.Cast()); } -unique_ptr HashedSort::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(); +const HashedSort::ChunkRows &HashedSort::GetHashGroups(GlobalSourceState &gstate) const { + auto &gsource = gstate.Cast(); + return gsource.chunk_rows; } -vector &HashedSort::GetHashGroups(GlobalSourceState &gstate) const { - auto &gsource = gstate.Cast(); - return gsource.hash_groups; +static SourceResultType MaterializeHashGroupData(ExecutionContext &context, idx_t hash_bin, bool build_runs, + OperatorSourceInput &source) { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; + + // OVER(PARTITION BY...) + if (gsink.grouping_data) { + lock_guard reset_guard(hash_group.scan_lock); + auto &partitions = gsink.grouping_data->GetPartitions(); + if (hash_bin < partitions.size()) { + // Release the memory now that we have finished scanning it. + partitions[hash_bin].reset(); + } + } + + auto &sort = hash_group.sort; + auto &sort_global = *hash_group.sort_source; + auto sort_local = sort.GetLocalSourceState(context, sort_global); + + OperatorSourceInput input {sort_global, *sort_local, source.interrupt_state}; + if (build_runs) { + return sort.MaterializeSortedRun(context, input); + } else { + return sort.MaterializeColumnData(context, input); + } } -SinkFinalizeType HashedSort::MaterializeHashGroups(Pipeline &pipeline, Event &event, const PhysicalOperator &op, - OperatorSinkFinalizeInput &finalize) const { - auto &gsink = finalize.global_state.Cast(); +SourceResultType HashedSort::MaterializeColumnData(ExecutionContext &execution, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeHashGroupData(execution, hash_bin, false, source); +} - // OVER() - if (sort_col_count == 0) { - auto &hash_group = *gsink.hash_groups[0]; - auto &unsorted = *hash_group.sorted; - if (!unsorted.Count()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; +HashedSort::HashGroupPtr HashedSort::GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; + + auto &sort = hash_group.sort; + auto &sort_global = *hash_group.sort_source; + + OperatorSourceInput input {sort_global, source.local_state, source.interrupt_state}; + auto result = sort.GetColumnData(input); + hash_group.sort_source.reset(); + + // Just because MaterializeColumnData returned FINISHED doesn't mean that the same thread will + // get the result... 
+ if (result && result->Count() == hash_group.count) { + return result; } - // Schedule all the sorts for maximum thread utilisation - auto sort_event = make_shared_ptr(gsink, pipeline, op); - event.InsertEvent(std::move(sort_event)); + return nullptr; +} - return SinkFinalizeType::READY; +SourceResultType HashedSort::MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const { + return MaterializeHashGroupData(context, hash_bin, true, source); +} + +HashedSort::SortedRunPtr HashedSort::GetSortedRun(ClientContext &client, idx_t hash_bin, + OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_groups[hash_bin]; + + auto &sort = hash_group.sort; + auto &sort_global = *hash_group.sort_source; + + auto result = sort.GetSortedRun(sort_global); + if (!result) { + D_ASSERT(hash_group.count == 0); + result = make_uniq(client, sort, false); + } + + hash_group.sort_source.reset(); + + return result; } } // namespace duckdb diff --git a/src/duckdb/src/common/sort/merge_sorter.cpp b/src/duckdb/src/common/sort/merge_sorter.cpp deleted file mode 100644 index c670fd574..000000000 --- a/src/duckdb/src/common/sort/merge_sorter.cpp +++ /dev/null @@ -1,667 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { -} - -void MergeSorter::PerformInMergeRound() { - while (true) { - // Check for interrupts after merging a partition - if (state.context.interrupted) { - throw InterruptException(); - } - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = *left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if 
(l_remaining + r_remaining == 0) { - // Done - break; - } - const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end <= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - 
data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - 
r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const 
bool r_done = r.block_idx == r_blocks.size(); - // Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == 
r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for (idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= - r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - 
result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock 
&target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/natural_sort.cpp b/src/duckdb/src/common/sort/natural_sort.cpp new file mode 100644 index 000000000..d8101a34a --- /dev/null +++ b/src/duckdb/src/common/sort/natural_sort.cpp @@ -0,0 +1,216 @@ +#include "duckdb/common/sorting/natural_sort.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// NaturalSortGroup +//===--------------------------------------------------------------------===// +class NaturalSortGroup { +public: + explicit NaturalSortGroup(ClientContext &client); + + atomic count; + + unique_ptr columns; + atomic get_columns; +}; + +NaturalSortGroup::NaturalSortGroup(ClientContext &client) : count(0), get_columns(0) { +} + +//===--------------------------------------------------------------------===// +// NaturalSortGlobalSinkState +//===--------------------------------------------------------------------===// +class NaturalSortGlobalSinkState : public GlobalSinkState { +public: + using HashGroupPtr = unique_ptr; + + NaturalSortGlobalSinkState(ClientContext &client, const NaturalSort &natural_sort); + + ProgressData GetSinkProgress(ClientContext &context, const ProgressData source_progress) const; + + //! System and query state + const NaturalSort &natural_sort; + + //! 
Combined rows + mutable mutex lock; + HashGroupPtr hash_group; + + // Threading + atomic count; +}; + +NaturalSortGlobalSinkState::NaturalSortGlobalSinkState(ClientContext &client, const NaturalSort &natural_sort) + : natural_sort(natural_sort), count(0) { +} + +ProgressData NaturalSortGlobalSinkState::GetSinkProgress(ClientContext &client, const ProgressData source) const { + ProgressData result; + result.done = source.done / 2; + result.total = source.total; + result.invalid = source.invalid; + + return result; +} + +SinkFinalizeType NaturalSort::Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const { + auto &gsink = finalize.global_state.Cast(); + + // Did we get any data? + return gsink.count ? SinkFinalizeType::READY : SinkFinalizeType::NO_OUTPUT_POSSIBLE; +} + +ProgressData NaturalSort::GetSinkProgress(ClientContext &client, GlobalSinkState &gstate, + const ProgressData source) const { + auto &gsink = gstate.Cast(); + return gsink.GetSinkProgress(client, source); +} + +//===--------------------------------------------------------------------===// +// NaturalSortLocalSinkState +//===--------------------------------------------------------------------===// +class NaturalSortLocalSinkState : public LocalSinkState { +public: + NaturalSortLocalSinkState(ExecutionContext &context, const NaturalSort &natural_sort); + + //! Global state + const NaturalSort &natural_sort; + + //! Merge the state into the global state. + void Combine(ExecutionContext &context); + + // OVER() (no sorting) + unique_ptr unsorted; + ColumnDataAppendState unsorted_append; +}; + +NaturalSortLocalSinkState::NaturalSortLocalSinkState(ExecutionContext &context, const NaturalSort &natural_sort) + : natural_sort(natural_sort) { + unsorted = make_uniq(context.client, natural_sort.payload_types); + unsorted->InitializeAppend(unsorted_append); +} + +SinkResultType NaturalSort::Sink(ExecutionContext &context, DataChunk &input_chunk, OperatorSinkInput &sink) const { + auto &gstate = sink.global_state.Cast(); + auto &lstate = sink.local_state.Cast(); + gstate.count += input_chunk.size(); + + lstate.unsorted->Append(lstate.unsorted_append, input_chunk); + return SinkResultType::NEED_MORE_INPUT; +} + +SinkCombineResultType NaturalSort::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { + auto &gstate = combine.global_state.Cast(); + auto &lstate = combine.local_state.Cast(); + + // Only one partition, so need a global lock. 
+ lock_guard glock(gstate.lock); + auto &hash_group = gstate.hash_group; + if (hash_group) { + auto &unsorted = *hash_group->columns; + if (lstate.unsorted) { + hash_group->count += lstate.unsorted->Count(); + unsorted.Combine(*lstate.unsorted); + lstate.unsorted.reset(); + } + } else { + hash_group = make_uniq(context.client); + hash_group->columns = std::move(lstate.unsorted); + hash_group->count += hash_group->columns->Count(); + } + return SinkCombineResultType::FINISHED; +} + +//===--------------------------------------------------------------------===// +// NaturalSortGlobalSourceState +//===--------------------------------------------------------------------===// +class NaturalSortGlobalSourceState : public GlobalSourceState { +public: + using ChunkRow = NaturalSort::ChunkRow; + using ChunkRows = NaturalSort::ChunkRows; + + NaturalSortGlobalSourceState(ClientContext &client, NaturalSortGlobalSinkState &gsink); + + NaturalSortGlobalSinkState &gsink; + ChunkRows chunk_rows; +}; + +NaturalSortGlobalSourceState::NaturalSortGlobalSourceState(ClientContext &client, NaturalSortGlobalSinkState &gsink) + : gsink(gsink) { + if (!gsink.count) { + return; + } + + // One unsorted group. We have the count and chunks. + ChunkRow chunk_row; + + auto &hash_group = gsink.hash_group; + if (hash_group) { + chunk_row.count = hash_group->count; + chunk_row.chunks = hash_group->columns->ChunkCount(); + } + + chunk_rows.emplace_back(chunk_row); +} + +//===--------------------------------------------------------------------===// +// NaturalSort +//===--------------------------------------------------------------------===// +NaturalSort::NaturalSort(const Types &input_types) : SortStrategy(input_types) { +} + +unique_ptr NaturalSort::GetGlobalSinkState(ClientContext &client) const { + return make_uniq(client, *this); +} + +unique_ptr NaturalSort::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(context, *this); +} + +unique_ptr NaturalSort::GetGlobalSourceState(ClientContext &client, GlobalSinkState &sink) const { + return make_uniq(client, sink.Cast()); +} + +unique_ptr NaturalSort::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(); +} + +const NaturalSort::ChunkRows &NaturalSort::GetHashGroups(GlobalSourceState &gstate) const { + auto &gsource = gstate.Cast(); + return gsource.chunk_rows; +} + +SourceResultType NaturalSort::MaterializeColumnData(ExecutionContext &execution, idx_t hash_bin, + OperatorSourceInput &source) const { + D_ASSERT(hash_bin == 0); + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_group; + + // Hack: Only report finished for the first call + return hash_group.get_columns++ ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +NaturalSort::HashGroupPtr NaturalSort::GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const { + auto &gsource = source.global_state.Cast(); + auto &gsink = gsource.gsink; + auto &hash_group = *gsink.hash_group; + + // OVER() + D_ASSERT(hash_bin == 0); + return std::move(hash_group.columns); +} + +SourceResultType NaturalSort::MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const { + throw InternalException("NaturalSort does not implement sorted runs."); +} + +NaturalSort::SortedRunPtr NaturalSort::GetSortedRun(ClientContext &client, idx_t hash_bin, + OperatorSourceInput &source) const { + throw InternalException("NaturalSort does not implement sorted runs."); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp deleted file mode 100644 index 2a0a65895..000000000 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ /dev/null @@ -1,671 +0,0 @@ -#include "duckdb/common/sort/partition_state.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/parallel/executor_task.hpp" - -namespace duckdb { - -PartitionGlobalHashGroup::PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, - const Orders &orders, const Types &payload_types, bool external) - : count(0) { - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - global_sort = make_uniq(context, orders, payload_layout); - global_sort->external = external; - - // Set up a comparator for the partition subset - partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); -} - -void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks) { - D_ASSERT(count > 0); - - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - - partition_mask.SetValidUnsafe(0); - unordered_map prefixes; - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(0); - D_ASSERT(order_mask.first >= partition_layout.column_count); - prefixes[order_mask.first] = global_sort->sort_layout.GetPrefixComparisonLayout(order_mask.first); - } - - for (++curr; curr.GetIndex() < count; ++curr) { - // Compare the partition subset first because if that differs, then so does the full ordering - const auto part_cmp = ComparePartitions(prev, curr); - - if (part_cmp) { - partition_mask.SetValidUnsafe(curr.GetIndex()); - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } else { - for (auto &order_mask : order_masks) { - if (prev.Compare(curr, prefixes[order_mask.first])) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } - } - ++prev; - } -} - -void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, - const Orders &order_bys, - const vector> &partition_stats) { - - // we sort by both 1) partition by expression list and 2) order by expressions - const auto partition_cols = partition_bys.size(); - for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { - auto &pexpr = partition_bys[prt_idx]; - - if (partition_stats.empty() || !partition_stats[prt_idx]) { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); - } else { - 
orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), - partition_stats[prt_idx]->ToUnique()); - } - partitions.emplace_back(orders.back().Copy()); - } - - for (const auto &order : order_bys) { - orders.emplace_back(order.Copy()); - } -} - -PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, - const vector> &partition_bys, - const vector &order_bys, - const Types &payload_types, - const vector> &partition_stats, - idx_t estimated_cardinality) - : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), - fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { - - GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).force_external; - - const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); - while (max_bits < 10 && (thread_pages >> max_bits) > 1) { - ++max_bits; - } - - grouping_types_ptr = make_shared_ptr(); - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types_ptr->Initialize(payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - auto new_group = make_uniq(context, partitions, orders, payload_types, external); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - ResizeGroupingData(estimated_cardinality); - } - } -} - -bool PartitionGlobalSinkState::HasMergeTasks() const { - if (grouping_data) { - auto &groups = grouping_data->GetPartitions(); - return !groups.empty(); - } else if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - return hash_groups[0]->count > 0; - } else { - return false; - } -} - -void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { - fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; - - const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; - if (fixed_bits != old_bits) { - const auto hash_col_idx = payload_types.size(); - grouping_data = - make_uniq(buffer_manager, grouping_types_ptr, fixed_bits, hash_col_idx); - } -} - -unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { - const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types_ptr, new_bits, hash_col_idx); -} - -void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { - // Have we started to combine? Then just live with it. - if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { - return; - } - // Is the average partition size too large? - const idx_t partition_size = DEFAULT_ROW_GROUP_SIZE; - const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; - auto new_bits = bits ? bits : 4; - while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { - ++new_bits; - } - - // Repartition the grouping data - if (new_bits != bits) { - grouping_data = CreatePartition(new_bits); - } -} - -void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // We are done if the local_partition is right sized. 
- auto &local_radix = local_partition->Cast(); - const auto new_bits = grouping_data->GetRadixBits(); - if (local_radix.GetRadixBits() == new_bits) { - return; - } - - // If the local partition is now too small, flush it and reallocate - auto new_partition = CreatePartition(new_bits); - local_partition->FlushAppendState(*local_append); - local_partition->Repartition(context, *new_partition); - - local_partition = std::move(new_partition); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); -} - -void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // Make sure grouping_data doesn't change under us. - lock_guard guard(lock); - - if (!local_partition) { - local_partition = CreatePartition(grouping_data->GetRadixBits()); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); - return; - } - - // Grow the groups if they are too big - ResizeGroupingData(count); - - // Sync local partition to have the same bit count - SyncLocalPartition(local_partition, local_append); -} - -void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - if (!local_partition) { - return; - } - local_partition->FlushAppendState(*local_append); - - // Make sure grouping_data doesn't change under us. - // Combine has an internal mutex, so this is single-threaded anyway. - lock_guard guard(lock); - SyncLocalPartition(local_partition, local_append); - grouping_data->Combine(*local_partition); -} - -PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) - : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { - - // Set up the sort expression computation. - vector sort_types; - for (auto &order : gstate.orders) { - auto &oexpr = order.expression; - sort_types.emplace_back(oexpr->return_type); - executor.AddExpression(*oexpr); - } - sort_chunk.Initialize(gstate.allocator, sort_types); - payload_chunk.Initialize(gstate.allocator, gstate.payload_types); -} - -void PartitionLocalMergeState::Scan() { - if (!merge_state->group_data) { - // OVER(ORDER BY...) - // Already sorted - return; - } - - auto &group_data = *merge_state->group_data; - auto &hash_group = *merge_state->hash_group; - auto &chunk_state = merge_state->chunk_state; - // Copy the data from the group into the sort code. 
- auto &global_sort = *hash_group.global_sort; - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - TupleDataScanState local_scan; - group_data.InitializeScan(local_scan, merge_state->column_ids); - while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { - sort_chunk.Reset(); - executor.Execute(payload_chunk, sort_chunk); - - local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { - local_sort.Sort(global_sort, true); - } - hash_group.count += payload_chunk.size(); - } - - global_sort.AddLocalState(local_sort); -} - -// Per-thread sink state -PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { - - vector group_types; - for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { - auto &pexpr = *gstate.partitions[prt_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - sort_cols = gstate.orders.size() + group_types.size(); - - if (sort_cols) { - auto payload_types = gstate.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { - auto &pexpr = *gstate.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &global_sort = *gstate.hash_groups[0]->global_sort; - local_sort = make_uniq(); - local_sort->Initialize(global_sort, global_sort.buffer_manager); - } - // OVER(...) - payload_chunk.Initialize(allocator, payload_types); - } else { - // OVER() - payload_layout.Initialize(gstate.payload_types); - } -} - -void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); - D_ASSERT(group_chunk.ColumnCount() > 0); - - // OVER(PARTITION BY...) 
(hash grouping) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - VectorOperations::Hash(group_chunk.data[0], hash_vector, count); - for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { - VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); - } -} - -void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { - gstate.count += input_chunk.size(); - - // OVER() - if (sort_cols == 0) { - // No sorts, so build paged row chunks - if (!rows) { - const auto entry_size = payload_layout.GetRowWidth(); - const auto block_size = gstate.buffer_manager.GetBlockSize(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, block_size / entry_size + 1); - rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, block_size, 1U, true); - } - const auto row_count = input_chunk.size(); - const auto row_sel = FlatVector::IncrementalSelectionVector(); - Vector addresses(LogicalType::POINTER); - auto key_locations = FlatVector::GetData(addresses); - const auto prev_rows_blocks = rows->blocks.size(); - auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); - auto input_data = input_chunk.ToUnifiedFormat(); - RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); - // Mark that row blocks contain pointers (heap blocks are pinned) - if (!payload_layout.AllConstant()) { - D_ASSERT(strings->keep_pinned); - for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { - rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); - } - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - local_sort->SinkChunk(group_chunk, input_chunk); - - auto &hash_group = *gstate.hash_groups[0]; - hash_group.count += input_chunk.size(); - - if (local_sort->SizeInBytes() > gstate.memory_per_thread) { - auto &global_sort = *hash_group.global_sort; - local_sort->Sort(global_sort, true); - } - return; - } - - // OVER(...) - payload_chunk.Reset(); - auto &hash_vector = payload_chunk.data.back(); - Hash(input_chunk, hash_vector); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { - payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); - } - payload_chunk.SetCardinality(input_chunk); - - gstate.UpdateLocalPartition(local_partition, local_append); - local_partition->Append(*local_append, payload_chunk); -} - -void PartitionLocalSinkState::Combine() { - // OVER() - if (sort_cols == 0) { - // Only one partition again, so need a global lock. - lock_guard glock(gstate.lock); - if (gstate.rows) { - if (rows) { - gstate.rows->Merge(*rows); - gstate.strings->Merge(*strings); - rows.reset(); - strings.reset(); - } - } else { - gstate.rows = std::move(rows); - gstate.strings = std::move(strings); - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - auto &hash_group = *gstate.hash_groups[0]; - auto &global_sort = *hash_group.global_sort; - global_sort.AddLocalState(*local_sort); - local_sort.reset(); - return; - } - - // OVER(...) 
- gstate.CombineLocalPartition(local_partition, local_append); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, - hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), group_idx(sink.hash_groups.size()), - memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - auto new_group = make_uniq(sink.context, sink.partitions, sink.orders, sink.payload_types, - sink.external); - sink.hash_groups.emplace_back(std::move(new_group)); - - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; - - column_ids.reserve(sink.payload_types.size()); - for (column_t i = 0; i < sink.payload_types.size(); ++i) { - column_ids.emplace_back(i); - } - group_data->InitializeScan(chunk_state, column_ids); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), group_idx(0), memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const hash_t hash_bin = 0; - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; -} - -void PartitionLocalMergeState::Prepare() { - merge_state->group_data.reset(); - - auto &global_sort = *merge_state->global_sort; - global_sort.PrepareMergePhase(); -} - -void PartitionLocalMergeState::Merge() { - auto &global_sort = *merge_state->global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); -} - -void PartitionLocalMergeState::Sorted() { - merge_state->sink.OnSortedPartition(merge_state->group_idx); -} - -void PartitionLocalMergeState::ExecuteTask() { - switch (stage) { - case PartitionSortStage::SCAN: - Scan(); - break; - case PartitionSortStage::PREPARE: - Prepare(); - break; - case PartitionSortStage::MERGE: - Merge(); - break; - case PartitionSortStage::SORTED: - Sorted(); - break; - default: - throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); - } - - merge_state->CompleteTask(); - finished = true; -} - -bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { - lock_guard guard(lock); - - if (tasks_assigned >= total_tasks && !TryPrepareNextStage()) { - return false; - } - - local_state.merge_state = this; - local_state.stage = stage; - local_state.finished = false; - tasks_assigned++; - - return true; -} - -void PartitionGlobalMergeState::CompleteTask() { - lock_guard guard(lock); - - ++tasks_completed; -} - -bool PartitionGlobalMergeState::TryPrepareNextStage() { - if (tasks_completed < total_tasks) { - return false; - } - - tasks_assigned = tasks_completed = 0; - - switch (stage.load()) { - case PartitionSortStage::INIT: - // If the partitions are unordered, don't scan in parallel - // because it produces non-deterministic orderings. - // This can theoretically happen with ORDER BY, - // but that is something the query should be explicit about. - total_tasks = sink.orders.size() > sink.partitions.size() ? 
num_threads : 1; - stage = PartitionSortStage::SCAN; - return true; - - case PartitionSortStage::SCAN: - total_tasks = 1; - stage = PartitionSortStage::PREPARE; - return true; - - case PartitionSortStage::PREPARE: - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - stage = PartitionSortStage::MERGE; - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::MERGE: - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::SORTED: - stage = PartitionSortStage::FINISHED; - total_tasks = 0; - return false; - - case PartitionSortStage::FINISHED: - return false; - } - - stage = PartitionSortStage::SORTED; - total_tasks = 1; - - return true; -} - -PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { - // Schedule all the sorts for maximum thread utilisation - if (sink.grouping_data) { - auto &partitions = sink.grouping_data->GetPartitions(); - sink.bin_groups.resize(partitions.size(), partitions.size()); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { - auto &group_data = partitions[hash_bin]; - // Prepare for merge sort phase - if (group_data->Count()) { - auto state = make_uniq(sink, std::move(group_data), hash_bin); - states.emplace_back(std::move(state)); - } - } - } else { - // OVER(ORDER BY...) - // Already sunk into the single global sort, so set up single merge with no data - sink.bin_groups.resize(1, 1); - auto state = make_uniq(sink); - states.emplace_back(std::move(state)); - } - - sink.OnBeginMerge(); -} - -class PartitionMergeTask : public ExecutorTask { -public: - PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate, const PhysicalOperator &op) - : ExecutorTask(context_p, std::move(event_p), op), local_state(gstate), hash_groups(hash_groups_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - - string TaskType() const override { - return "PartitionMergeTask"; - } - -private: - struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { - explicit ExecutorCallback(Executor &executor) : executor(executor) { - } - - bool HasError() const override { - return executor.HasError(); - } - - Executor &executor; - }; - - PartitionLocalMergeState local_state; - PartitionGlobalMergeStates &hash_groups; -}; - -bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { - // Loop until all hash groups are done - size_t sorted = 0; - while (sorted < states.size()) { - // First check if there is an unfinished task for this thread - if (callback.HasError()) { - return false; - } - if (!local_state.TaskFinished()) { - local_state.ExecuteTask(); - continue; - } - - // Thread is done with its assigned task, try to fetch new work - for (auto group = sorted; group < states.size(); ++group) { - auto &global_state = states[group]; - if (global_state->IsFinished()) { - // This hash group is done - // Update the high water mark of densely completed groups - if (sorted == group) { - ++sorted; - } - continue; - } - - // Try to assign work for this hash group to this thread - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! 
- // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // We were able to prepare the next merge round, - // but we were not able to assign a task for it to this thread - // The tasks were assigned to other threads while this thread waited for the lock - // Go to the next iteration to see if another hash group has a task - } - } - - return true; -} - -TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutorCallback callback(executor); - - if (!hash_groups.ExecuteTask(local_state, callback)) { - return TaskExecutionResult::TASK_ERROR; - } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void PartitionMergeEvent::Schedule() { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate, op)); - } - SetTasks(std::move(merge_tasks)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp deleted file mode 100644 index b193cee61..000000000 --- a/src/duckdb/src/common/sort/radix_sort.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/duckdb_pdqsort.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -//! Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! 
Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } - - if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } - - if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } - - const auto block_size = buffer_manager.GetBlockSize(); - auto temp_block = - buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); - auto pre_allocated_array = - make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - pre_allocated_array.get(), false); -} - -//! 
Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array_uninitialized(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - ties[count - 1] = false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sort.cpp b/src/duckdb/src/common/sort/sort.cpp similarity index 94% rename from src/duckdb/src/common/sorting/sort.cpp rename to src/duckdb/src/common/sort/sort.cpp index 2159878ff..b46db0a5f 100644 --- a/src/duckdb/src/common/sorting/sort.cpp +++ b/src/duckdb/src/common/sort/sort.cpp @@ -141,7 +141,7 @@ class SortLocalSinkState : public LocalSinkState { D_ASSERT(!sorted_run); // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART - sorted_run = make_uniq(context, sort.key_layout, sort.payload_layout, false); + sorted_run = 
make_uniq(context, sort, false); } public: @@ -161,7 +161,7 @@ class SortGlobalSinkState : public GlobalSinkState { public: explicit SortGlobalSinkState(ClientContext &context) : num_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), - temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), + temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), sorted_tuples(0), external(ClientConfig::GetConfig(context).force_external), any_combined(false), total_count(0), partition_size(0) { } @@ -366,8 +366,7 @@ ProgressData Sort::GetSinkProgress(ClientContext &context, GlobalSinkState &gsta class SortGlobalSourceState : public GlobalSourceState { public: SortGlobalSourceState(const Sort &sort, ClientContext &context, SortGlobalSinkState &sink_p) - : sink(sink_p), merger(*sort.decode_sort_key, sort.key_layout, std::move(sink.sorted_runs), - sort.output_projection_columns, sink.partition_size, sink.external, false), + : sink(sink_p), merger(sort, std::move(sink.sorted_runs), sink.partition_size, sink.external, false), merger_global_state(merger.total_count == 0 ? nullptr : merger.GetGlobalSourceState(context)) { // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART @@ -378,6 +377,15 @@ class SortGlobalSourceState : public GlobalSourceState { return merger_global_state ? merger_global_state->MaxThreads() : 1; } + void Destroy() { + if (!merger_global_state) { + return; + } + auto guard = merger_global_state->Lock(); + merger.sorted_runs.clear(); + sink.temporary_memory_state.reset(); + } + public: //! The global sink state SortGlobalSinkState &sink; @@ -456,7 +464,8 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator chunk.Initialize(context.client, types); // Initialize local output collection - auto local_column_data = make_uniq(context.client, types, true); + auto local_column_data = + make_uniq(context.client, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); while (true) { // Check for interrupts since this could be a long-running task @@ -477,16 +486,26 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator } // Merge into global output collection - auto guard = gstate.Lock(); - if (!gstate.column_data) { - gstate.column_data = std::move(local_column_data); - } else { - gstate.column_data->Merge(*local_column_data); + { + auto guard = gstate.Lock(); + if (!gstate.column_data) { + gstate.column_data = std::move(local_column_data); + } else { + gstate.column_data->Merge(*local_column_data); + } } + // Destroy local state before returning + input.local_state.Cast().merger_local_state.reset(); + // Return type indicates whether materialization is done const auto progress_data = GetProgress(context.client, input.global_state); - return progress_data.done == progress_data.total ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + if (progress_data.done == progress_data.total) { + // Destroy global state before returning + gstate.Destroy(); + return SourceResultType::FINISHED; + } + return SourceResultType::HAVE_MORE_OUTPUT; } unique_ptr Sort::GetColumnData(OperatorSourceInput &input) const { @@ -502,12 +521,15 @@ SourceResultType Sort::MaterializeSortedRun(ExecutionContext &context, OperatorS } auto &lstate = input.local_state.Cast(); OperatorSourceInput merger_input {*gstate.merger_global_state, *lstate.merger_local_state, input.interrupt_state}; - return gstate.merger.MaterializeMerge(context, merger_input); + return gstate.merger.MaterializeSortedRun(context, merger_input); } unique_ptr Sort::GetSortedRun(GlobalSourceState &global_state) { auto &gstate = global_state.Cast(); - return gstate.merger.GetMaterialized(gstate); + if (gstate.merger.total_count == 0) { + return nullptr; + } + return gstate.merger.GetSortedRun(*gstate.merger_global_state); } } // namespace duckdb diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp deleted file mode 100644 index 369f032f1..000000000 --- a/src/duckdb/src/common/sort/sort_state.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if (TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - case PhysicalType::ARRAY: - // Arrays get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, ArrayType::GetChildType(type)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for (idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 
1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - const auto block_size = buffer_manager->GetBlockSize(); - - // Radix sorting data - auto entries_per_block = 
RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); - radix_sorting_data = make_uniq(*buffer_manager, entries_per_block, sort_layout->entry_size); - - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); - blob_sorting_data = make_uniq(*buffer_manager, entries_per_block, blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, block_size, 1U, true); - } - - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); - payload_data = make_uniq(*buffer_manager, entries_per_block, payload_row_width); - payload_heap = make_uniq(*buffer_manager, block_size, 1U, true); - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - 
sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. - if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto &buffer_manager = row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager.Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager.Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, - unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? 
nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(ClientContext &context_p, const vector &orders, - RowLayout &payload_layout) - : context(context_p), buffer_manager(BufferManager::GetBufferManager(context)), sort_layout(SortLayout(orders)), - payload_layout(payload_layout), block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = 
local_sort_state.payload_heap; - for (idx_t i = 0; i < payload_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to use do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! - std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done! 
- if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sort_strategy.cpp b/src/duckdb/src/common/sort/sort_strategy.cpp new file mode 100644 index 000000000..123131318 --- /dev/null +++ b/src/duckdb/src/common/sort/sort_strategy.cpp @@ -0,0 +1,46 @@ +#include "duckdb/common/sorting/sort_strategy.hpp" +#include "duckdb/common/sorting/full_sort.hpp" +#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/natural_sort.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// SortStrategy +//===--------------------------------------------------------------------===// +SortStrategy::SortStrategy(const Types &input_types) : payload_types(input_types) { + // The payload prefix is the same as the input schema + for (column_t i = 0; i < payload_types.size(); ++i) { + scan_ids.emplace_back(i); + } +} + +void SortStrategy::Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const { +} + +void SortStrategy::SortColumnData(ExecutionContext &context, hash_t hash_bin, OperatorSinkFinalizeInput &finalize) { + // Nothing to sort + return; +} + +unique_ptr SortStrategy::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(); +} + +unique_ptr SortStrategy::Factory(ClientContext &client, + const vector> &partition_bys, + const vector &order_bys, const Types &payload_types, + const vector> &partitions_stats, + idx_t estimated_cardinality, bool require_payload) { + if (!partition_bys.empty()) { + return make_uniq(client, partition_bys, order_bys, payload_types, partitions_stats, + estimated_cardinality, require_payload); + } else if (!order_bys.empty()) { + return make_uniq(client, order_bys, payload_types, require_payload); + } else { + return make_uniq(payload_types); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp deleted file mode 100644 index c4766c956..000000000 --- a/src/duckdb/src/common/sort/sorted_block.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/common/sort/sorted_block.hpp" - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - const auto block_size = 
buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = buffer_manager.Pin(heap_block->block); - RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); - radix_sorting_data.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto 
&radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - } - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, end_block_index, end_entry_index); - // Add the corresponding blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - 
bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? 
blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case 
ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), external(gss.external), - cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sorted_run.cpp b/src/duckdb/src/common/sort/sorted_run.cpp similarity index 67% rename from src/duckdb/src/common/sorting/sorted_run.cpp rename to src/duckdb/src/common/sort/sorted_run.cpp index 57c390d32..a22f0a0b3 100644 --- a/src/duckdb/src/common/sorting/sorted_run.cpp +++ b/src/duckdb/src/common/sort/sorted_run.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -9,14 +10,145 @@ namespace duckdb { -SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort_p) - : context(context_p), - key_data(make_uniq(BufferManager::GetBufferManager(context), std::move(key_layout))), - payload_data( - payload_layout && payload_layout->ColumnCount() != 0 - ? 
make_uniq(BufferManager::GetBufferManager(context), std::move(payload_layout)) - : nullptr), +//===--------------------------------------------------------------------===// +// SortedRunScanState +//===--------------------------------------------------------------------===// +SortedRunScanState::SortedRunScanState(ClientContext &context, const Sort &sort_p) + : sort(sort_p), key_executor(context, *sort.decode_sort_key) { + key.Initialize(context, {sort.key_layout->GetTypes()[0]}); + decoded_key.Initialize(context, {sort.decode_sort_key->return_type}); +} + +void SortedRunScanState::Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + const auto sort_key_type = sort.key_layout->GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + default: + throw NotImplementedException("SortedRunMergerLocalState::ScanPartition for %s", + EnumUtil::ToString(sort_key_type)); + } +} + +template +void TemplatedGetKeyAndPayload(SORT_KEY *const *const sort_keys, SORT_KEY *temp_keys, const idx_t &count, + DataChunk &key, data_ptr_t *const payload_ptrs) { + const auto key_data = FlatVector::GetData(key.data[0]); + for (idx_t i = 0; i < count; i++) { + auto &sort_key = temp_keys[i]; + sort_key = *sort_keys[i]; + sort_key.Deconstruct(key_data[i]); + if (SORT_KEY::HAS_PAYLOAD) { + payload_ptrs[i] = sort_key.GetPayload(); + } + } + key.SetCardinality(count); +} + +template +void GetKeyAndPayload(SORT_KEY *const *const sort_keys, SORT_KEY *temp_keys, const idx_t &count, DataChunk &key, + data_ptr_t *const payload_ptrs) { + const auto type_id = key.data[0].GetType().id(); + switch (type_id) { + case LogicalTypeId::BLOB: + return TemplatedGetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + case LogicalTypeId::BIGINT: + return TemplatedGetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + default: + throw NotImplementedException("GetKeyAndPayload for %s", EnumUtil::ToString(type_id)); + } +} + +template +void SortedRunScanState::TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + using SORT_KEY = SortKey; + + const auto &output_projection_columns = sort.output_projection_columns; + idx_t opc_idx = 0; + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); + bool gathered_payload = false; + + // Decode from key + if (!output_projection_columns[0].is_payload) { + key.Reset(); + key_buffer.resize(count * 
sizeof(SORT_KEY)); + auto temp_keys = reinterpret_cast(key_buffer.data()); + GetKeyAndPayload(sort_keys, temp_keys, count, key, payload_ptrs); + + decoded_key.Reset(); + key_executor.Execute(key, decoded_key); + + const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + if (opc.is_payload) { + break; + } + chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); + } + + gathered_payload = true; + } + + // If there are no payload columns, we're done here + if (opc_idx != output_projection_columns.size()) { + if (!gathered_payload) { + // Gather row pointers from keys + for (idx_t i = 0; i < count; i++) { + payload_ptrs[i] = sort_keys[i]->GetPayload(); + } + } + + // Init scan state + auto &payload_data = *sorted_run.payload_data; + if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { + payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); + } + TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); + + // Now gather from payload + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + D_ASSERT(opc.is_payload); + payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), + count, opc.layout_col_idx, chunk.data[opc.output_col_idx], + *FlatVector::IncrementalSelectionVector(), + payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); + } + } + + chunk.SetCardinality(count); +} + +//===--------------------------------------------------------------------===// +// SortedRun +//===--------------------------------------------------------------------===// +SortedRun::SortedRun(ClientContext &context_p, const Sort &sort_p, bool is_index_sort_p) + : context(context_p), sort(sort_p), + key_data(make_uniq(context, sort.key_layout, MemoryTag::ORDER_BY)), + payload_data(sort.payload_layout && sort.payload_layout->ColumnCount() != 0 + ? make_uniq(context, sort.payload_layout, MemoryTag::ORDER_BY) + : nullptr), is_index_sort(is_index_sort_p), finalized(false) { key_data->InitializeAppend(key_append_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED); if (payload_data) { @@ -25,8 +157,7 @@ SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_l } unique_ptr SortedRun::CreateRunForMaterialization() const { - auto res = make_uniq(context, key_data->GetLayoutPtr(), - payload_data ? 
payload_data->GetLayoutPtr() : nullptr, is_index_sort); + auto res = make_uniq(context, sort, is_index_sort); res->key_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->payload_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->finalized = true; diff --git a/src/duckdb/src/common/sorting/sorted_run_merger.cpp b/src/duckdb/src/common/sort/sorted_run_merger.cpp similarity index 87% rename from src/duckdb/src/common/sorting/sorted_run_merger.cpp rename to src/duckdb/src/common/sort/sorted_run_merger.cpp index eb879edc5..ee18bc734 100644 --- a/src/duckdb/src/common/sorting/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sort/sorted_run_merger.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/sorting/sorted_run_merger.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -100,7 +101,7 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Whether this thread has finished the work it has been assigned bool TaskFinished() const; //! Do the work this thread has been assigned - void ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); + SourceResultType ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); private: //! Computes upper partition boundaries using K-way Merge Path @@ -154,12 +155,10 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Variables for scanning idx_t merged_partition_count; idx_t merged_partition_index; - TupleDataScanState payload_state; - //! For decoding sort keys - ExpressionExecutor key_executor; - DataChunk key; - DataChunk decoded_key; + //! For scanning + Vector sort_key_pointers; + SortedRunScanState sorted_run_scan_state; }; //===--------------------------------------------------------------------===// @@ -172,7 +171,7 @@ class SortedRunMergerGlobalState : public GlobalSourceState { merger(merger_p), num_runs(merger.sorted_runs.size()), num_partitions((merger.total_count + (merger.partition_size - 1)) / merger.partition_size), iterator_state_type(GetBlockIteratorStateType(merger.external)), - sort_key_type(merger.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), + sort_key_type(merger.sort.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), destroy_partition_idx(0) { // Initialize partitions partitions.resize(num_partitions); @@ -263,6 +262,11 @@ class SortedRunMergerGlobalState : public GlobalSourceState { destroy_partition_idx = end_partition_idx; } +private: + static BlockIteratorStateType GetBlockIteratorStateType(const bool &external) { + return external ? 
BlockIteratorStateType::EXTERNAL : BlockIteratorStateType::IN_MEMORY; + } + public: ClientContext &context; const idx_t num_threads; @@ -292,7 +296,7 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState : iterator_state_type(gstate.iterator_state_type), sort_key_type(gstate.sort_key_type), task(SortedRunMergerTask::FINISHED), run_boundaries(gstate.num_runs), merged_partition_count(DConstants::INVALID_INDEX), merged_partition_index(DConstants::INVALID_INDEX), - key_executor(gstate.context, gstate.merger.decode_sort_key) { + sorted_run_scan_state(gstate.context, gstate.merger.sort), sort_key_pointers(LogicalType::POINTER) { for (const auto &run : gstate.merger.sorted_runs) { auto &key_data = *run->key_data; switch (iterator_state_type) { @@ -308,8 +312,6 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState EnumUtil::ToString(iterator_state_type)); } } - key.Initialize(gstate.context, {gstate.merger.key_layout->GetTypes()[0]}); - decoded_key.Initialize(gstate.context, {gstate.merger.decode_sort_key.return_type}); } bool SortedRunMergerLocalState::TaskFinished() const { @@ -328,7 +330,8 @@ bool SortedRunMergerLocalState::TaskFinished() const { } } -void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk) { +SourceResultType SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, + optional_ptr chunk) { D_ASSERT(task != SortedRunMergerTask::FINISHED); switch (task) { case SortedRunMergerTask::COMPUTE_BOUNDARIES: @@ -352,14 +355,20 @@ void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, if (!chunk || chunk->size() == 0) { gstate.DestroyScannedData(); gstate.partitions[partition_idx.GetIndex()]->scanned = true; - gstate.total_scanned += merged_partition_count; + // fetch_add returns the _previous_ value! 
+ const auto scan_count_before_adding = gstate.total_scanned.fetch_add(merged_partition_count); + const auto scan_count_after_adding = scan_count_before_adding + merged_partition_count; partition_idx = optional_idx::Invalid(); task = SortedRunMergerTask::FINISHED; + if (scan_count_after_adding == gstate.merger.total_count) { + return SourceResultType::FINISHED; + } } break; default: throw NotImplementedException("SortedRunMergerLocalState::ExecuteTask for task"); } + return SourceResultType::HAVE_MORE_OUTPUT; } void SortedRunMergerLocalState::ComputePartitionBoundaries(SortedRunMergerGlobalState &gstate, @@ -685,94 +694,21 @@ void SortedRunMergerLocalState::ScanPartition(SortedRunMergerGlobalState &gstate } } -template -void TemplatedGetKeyAndPayload(SORT_KEY *const merged_partition_keys, const idx_t count, DataChunk &key, - data_ptr_t *const payload_ptrs) { - const auto key_data = FlatVector::GetData(key.data[0]); - for (idx_t i = 0; i < count; i++) { - auto &merged_partition_key = merged_partition_keys[i]; - merged_partition_key.Deconstruct(key_data[i]); - if (SORT_KEY::HAS_PAYLOAD) { - payload_ptrs[i] = merged_partition_key.GetPayload(); - } - } - key.SetCardinality(count); -} - -template -void GetKeyAndPayload(SORT_KEY *const merged_partition_keys, const idx_t count, DataChunk &key, - data_ptr_t *const payload_ptrs) { - const auto type_id = key.data[0].GetType().id(); - switch (type_id) { - case LogicalTypeId::BLOB: - return TemplatedGetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - case LogicalTypeId::BIGINT: - return TemplatedGetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - default: - throw NotImplementedException("GetKeyAndPayload for %s", EnumUtil::ToString(type_id)); - } -} - template void SortedRunMergerLocalState::TemplatedScanPartition(SortedRunMergerGlobalState &gstate, DataChunk &chunk) { using SORT_KEY = SortKey; const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - const auto &output_projection_columns = gstate.merger.output_projection_columns; - idx_t opc_idx = 0; - + // Grab pointers to sort keys const auto merged_partition_keys = reinterpret_cast(merged_partition.get()) + merged_partition_index; - const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); - bool gathered_payload = false; - - // Decode from key - if (!output_projection_columns[0].is_payload) { - key.Reset(); - GetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - - decoded_key.Reset(); - key_executor.Execute(key, decoded_key); - - const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - if (opc.is_payload) { - break; - } - chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); - } - gathered_payload = true; - } - - // If there are no payload columns, we're done here - if (opc_idx != output_projection_columns.size()) { - if (!gathered_payload) { - // Gather row pointers from keys - for (idx_t i = 0; i < count; i++) { - payload_ptrs[i] = merged_partition_keys[i].GetPayload(); - } - } - - // Init scan state - auto &payload_data = *gstate.merger.sorted_runs.back()->payload_data; - if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { - payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); - } - 
TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); - - // Now gather from payload - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - D_ASSERT(opc.is_payload); - payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), - count, opc.layout_col_idx, chunk.data[opc.output_col_idx], - *FlatVector::IncrementalSelectionVector(), - payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); - } + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < count; i++) { + sort_keys[i] = &merged_partition_keys[i]; } - merged_partition_index += count; - chunk.SetCardinality(count); + + // Scan + sorted_run_scan_state.Scan(*gstate.merger.sorted_runs[0], sort_key_pointers, count, chunk); } void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState &gstate) { @@ -812,7 +748,9 @@ void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState // Add to global state lock_guard guard(gstate.materialized_partition_lock); - gstate.materialized_partitions.resize(partition_idx.GetIndex()); + if (gstate.materialized_partitions.size() < partition_idx.GetIndex() + 1) { + gstate.materialized_partitions.resize(partition_idx.GetIndex() + 1); + } gstate.materialized_partitions[partition_idx.GetIndex()] = std::move(sorted_run); } @@ -833,7 +771,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S while (merged_partition_index < merged_partition_count) { const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < count + count; i++) { + for (idx_t i = 0; i < count; i++) { auto &key = merged_partition_keys[merged_partition_index + i]; key_locations[i] = data_ptr_cast(&key); if (!SORT_KEY::CONSTANT_SIZE) { @@ -855,7 +793,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S if (!sorted_run->payload_data->GetLayout().AllConstant()) { sorted_run->payload_data->FindHeapPointers(payload_data_input, count); } - sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(key_data_input.heap_sizes); + sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(payload_data_input.heap_sizes); sorted_run->payload_data->Build(sorted_run->payload_append_state.pin_state, sorted_run->payload_append_state.chunk_state, 0, count); sorted_run->payload_data->CopyRows(sorted_run->payload_append_state.chunk_state, payload_data_input, @@ -876,18 +814,16 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S //===--------------------------------------------------------------------===// // Sorted Run Merger //===--------------------------------------------------------------------===// -SortedRunMerger::SortedRunMerger(const Expression &decode_sort_key_p, shared_ptr key_layout_p, - vector> &&sorted_runs_p, - const vector &output_projection_columns_p, +SortedRunMerger::SortedRunMerger(const Sort &sort_p, vector> &&sorted_runs_p, idx_t partition_size_p, bool external_p, bool is_index_sort_p) - : decode_sort_key(decode_sort_key_p), key_layout(std::move(key_layout_p)), sorted_runs(std::move(sorted_runs_p)), - output_projection_columns(output_projection_columns_p), total_count(SortedRunsTotalCount(sorted_runs)), + : sort(sort_p), sorted_runs(std::move(sorted_runs_p)), total_count(SortedRunsTotalCount(sorted_runs)), partition_size(partition_size_p), 
external(external_p), is_index_sort(is_index_sort_p) { } unique_ptr SortedRunMerger::GetLocalSourceState(ExecutionContext &, GlobalSourceState &gstate_p) const { auto &gstate = gstate_p.Cast(); + auto guard = gstate.Lock(); return make_uniq(gstate); } @@ -929,30 +865,28 @@ ProgressData SortedRunMerger::GetProgress(ClientContext &, GlobalSourceState &gs //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// -SourceResultType SortedRunMerger::MaterializeMerge(ExecutionContext &, OperatorSourceInput &input) const { +SourceResultType SortedRunMerger::MaterializeSortedRun(ExecutionContext &, OperatorSourceInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); + SourceResultType res = SourceResultType::HAVE_MORE_OUTPUT; while (true) { if (!lstate.TaskFinished() || gstate.AssignTask(lstate)) { - lstate.ExecuteTask(gstate, nullptr); + res = lstate.ExecuteTask(gstate, nullptr); } else { break; } } - if (gstate.total_scanned == total_count) { - // This signals that the data has been fully materialized - return SourceResultType::FINISHED; - } - // This signals that no more tasks are left, but that the data has not yet been fully materialized - return SourceResultType::HAVE_MORE_OUTPUT; + // The thread that completes the materialization returns FINISHED, all other threads return HAVE_MORE_OUTPUT + return res; } -unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global_state) { +unique_ptr SortedRunMerger::GetSortedRun(GlobalSourceState &global_state) { auto &gstate = global_state.Cast(); + D_ASSERT(total_count != 0); + lock_guard guard(gstate.materialized_partition_lock); if (gstate.materialized_partitions.empty()) { - D_ASSERT(total_count == 0); return nullptr; } auto &target = *gstate.materialized_partitions[0]; @@ -963,7 +897,9 @@ unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global target.payload_data->Combine(*source.payload_data); } } - return std::move(gstate.materialized_partitions[0]); + auto res = std::move(gstate.materialized_partitions[0]); + gstate.materialized_partitions.clear(); + return res; } } // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index 51be7c3eb..9845961bb 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -10,6 +10,7 @@ #include "duckdb/original/std/sstream.hpp" #include "jaro_winkler.hpp" #include "utf8proc_wrapper.hpp" +#include "duckdb/common/types/string_type.hpp" #include #include @@ -34,6 +35,27 @@ string StringUtil::GenerateRandomName(idx_t length) { return ss.str(); } +bool StringUtil::Equals(const string_t &s1, const char *s2) { + auto s1_data = s1.GetData(); + for (idx_t i = 0; i < s1.GetSize(); i++) { + if (s1_data[i] != s2[i]) { + return false; + } + if (s2[i] == '\0') { + return false; + } + } + if (s2[s1.GetSize()] != '\0') { + // not equal + return false; + } + return true; +} + +bool StringUtil::Equals(const char *s1, const string_t &s2) { + return StringUtil::Equals(s2, s1); +} + bool StringUtil::Contains(const string &haystack, const string &needle) { return Find(haystack, needle).IsValid(); } @@ -287,9 +309,13 @@ bool StringUtil::IsUpper(const string &str) { // Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function uint64_t StringUtil::CIHash(const string &str) { + return StringUtil::CIHash(str.c_str(), str.size()); +} + 
+uint64_t StringUtil::CIHash(const char *str, idx_t size) {
 	uint32_t hash = 0;
-	for (auto c : str) {
-		hash += static_cast(StringUtil::CharacterToLower(static_cast(c)));
+	for (idx_t i = 0; i < size; i++) {
+		hash += static_cast(StringUtil::CharacterToLower(static_cast(str[i])));
 		hash += hash << 10;
 		hash ^= hash >> 6;
 	}
@@ -396,7 +422,10 @@ vector StringUtil::TopNStrings(vector> scores, idx_
 		return vector();
 	}
 	sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool {
-		return a.second > b.second || (a.second == b.second && a.first.size() < b.first.size());
+		if (a.second != b.second) {
+			return a.second > b.second;
+		}
+		return StringUtil::CILessThan(a.first, b.first);
 	});
 	vector result;
 	result.push_back(scores[0].first);
@@ -702,6 +731,21 @@ string StringUtil::ToComplexJSONMap(const ComplexJSON &complex_json) {
 	return ComplexJSON::GetValueRecursive(complex_json);
 }
 
+string StringUtil::ValidateJSON(const char *data, const idx_t &len) {
+	// Same flags as in JSON extension
+	static constexpr auto READ_FLAG =
+	    YYJSON_READ_ALLOW_INF_AND_NAN | YYJSON_READ_ALLOW_TRAILING_COMMAS | YYJSON_READ_BIGNUM_AS_RAW;
+	yyjson_read_err error;
+	yyjson_doc *doc = yyjson_read_opts((char *)data, len, READ_FLAG, nullptr, &error); // NOLINT: for yyjson
+	if (error.code != YYJSON_READ_SUCCESS) {
+		return StringUtil::Format("Malformed JSON at byte %lld of input: %s. Input: \"%s\"", error.pos, error.msg,
+		                          string(data, len));
+	}
+
+	yyjson_doc_free(doc);
+	return string();
+}
+
 string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message,
                                       const unordered_map &map) {
 	D_ASSERT(map.find("exception_type") == map.end());
@@ -719,7 +763,6 @@ string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message,
 }
 
 string StringUtil::GetFileName(const string &file_path) {
-
 	idx_t pos = file_path.find_last_of("/\\");
 	if (pos == string::npos) {
 		return file_path;
diff --git a/src/duckdb/src/common/thread_util.cpp b/src/duckdb/src/common/thread_util.cpp
new file mode 100644
index 000000000..0860d6d2a
--- /dev/null
+++ b/src/duckdb/src/common/thread_util.cpp
@@ -0,0 +1,14 @@
+#include "duckdb/common/thread.hpp"
+#include "duckdb/common/chrono.hpp"
+
+namespace duckdb {
+
+void ThreadUtil::SleepMs(idx_t sleep_ms) {
+#ifndef DUCKDB_NO_THREADS
+	std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
+#else
+	throw InvalidInputException("ThreadUtil::SleepMs requires DuckDB to be compiled with thread support");
+#endif
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer.cpp
index c8d97959b..c7a810468 100644
--- a/src/duckdb/src/common/tree_renderer.cpp
+++ b/src/duckdb/src/common/tree_renderer.cpp
@@ -4,6 +4,7 @@
 #include "duckdb/common/tree_renderer/html_tree_renderer.hpp"
 #include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp"
 #include "duckdb/common/tree_renderer/yaml_tree_renderer.hpp"
+#include "duckdb/common/tree_renderer/mermaid_tree_renderer.hpp"
 
 #include 
 
@@ -22,6 +23,8 @@ unique_ptr TreeRenderer::CreateRenderer(ExplainFormat format) {
 		return make_uniq();
 	case ExplainFormat::YAML:
 		return make_uniq();
+	case ExplainFormat::MERMAID:
+		return make_uniq();
 	default:
 		throw NotImplementedException("ExplainFormat %s not implemented", EnumUtil::ToString(format));
 	}
diff --git a/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp
new file mode 100644
index 000000000..9ff7b6539
--- /dev/null
+++ b/src/duckdb/src/common/tree_renderer/mermaid_tree_renderer.cpp
@@ -0,0 +1,133 @@
+#include "duckdb/common/tree_renderer/mermaid_tree_renderer.hpp"
+
+#include "duckdb/common/pair.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
+#include "duckdb/execution/operator/join/physical_delim_join.hpp"
+#include "duckdb/execution/operator/scan/physical_positional_scan.hpp"
+#include "duckdb/execution/physical_operator.hpp"
+#include "duckdb/parallel/pipeline.hpp"
+#include "duckdb/planner/logical_operator.hpp"
+#include "duckdb/main/query_profiler.hpp"
+#include "utf8proc_wrapper.hpp"
+
+#include 
+
+namespace duckdb {
+
+string MermaidTreeRenderer::ToString(const LogicalOperator &op) {
+	duckdb::stringstream ss;
+	Render(op, ss);
+	return ss.str();
+}
+
+string MermaidTreeRenderer::ToString(const PhysicalOperator &op) {
+	duckdb::stringstream ss;
+	Render(op, ss);
+	return ss.str();
+}
+
+string MermaidTreeRenderer::ToString(const ProfilingNode &op) {
+	duckdb::stringstream ss;
+	Render(op, ss);
+	return ss.str();
+}
+
+string MermaidTreeRenderer::ToString(const Pipeline &op) {
+	duckdb::stringstream ss;
+	Render(op, ss);
+	return ss.str();
+}
+
+void MermaidTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) {
+	auto tree = RenderTree::CreateRenderTree(op);
+	ToStream(*tree, ss);
+}
+
+void MermaidTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) {
+	auto tree = RenderTree::CreateRenderTree(op);
+	ToStream(*tree, ss);
+}
+
+void MermaidTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) {
+	auto tree = RenderTree::CreateRenderTree(op);
+	ToStream(*tree, ss);
+}
+
+void MermaidTreeRenderer::Render(const Pipeline &op, std::ostream &ss) {
+	auto tree = RenderTree::CreateRenderTree(op);
+	ToStream(*tree, ss);
+}
+
+static string SanitizeMermaidLabel(const string &text) {
+	string result;
+	result.reserve(text.size() * 2); // Reserve more space for potential escape sequences
+	for (size_t i = 0; i < text.size(); i++) {
+		char c = text[i];
+		// Escape backticks and quotes
+		if (c == '`') {
+			result += "\\`";
+		} else if (c == '"') {
+			result += "\\\"";
+		} else if (c == '\\' && i + 1 < text.size() && text[i + 1] == 'n') {
+			// Replace literal "\n" with actual newline for Mermaid markdown
+			result += "\n\t";
+			i++; // Skip the 'n'
+		} else {
+			result += c;
+		}
+	}
+	return result;
+}
+
+void MermaidTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) {
+	vector nodes;
+	vector edges;
+
+	const string node_format = " node_%d_%d[\"`**%s**%s`\"]";
+
+	for (idx_t y = 0; y < root.height; y++) {
+		for (idx_t x = 0; x < root.width; x++) {
+			auto node = root.GetNode(x, y);
+			if (!node) {
+				continue;
+			}
+
+			// Build node label with markdown formatting
+			string extra_info;
+			for (auto &item : node->extra_text) {
+				auto &key = item.first;
+				auto &value_raw = item.second;
+
+				auto value = QueryProfiler::JSONSanitize(value_raw);
+				// Add newline and key-value pair
+				extra_info += StringUtil::Format("\n\t%s: %s", key, SanitizeMermaidLabel(value));
+			}
+
+			// Create node with bold operator name and extra info (trim name to remove trailing spaces)
+			auto trimmed_name = node->name;
+			StringUtil::Trim(trimmed_name);
+			nodes.push_back(StringUtil::Format(node_format, x, y, SanitizeMermaidLabel(trimmed_name), extra_info));
+
+			// Create Edge(s)
+			for (auto &coord : node->child_positions) {
+				edges.push_back(StringUtil::Format(" node_%d_%d --> node_%d_%d", x, y, coord.x, coord.y));
+			}
+		}
+	}
+
+ // Output Mermaid flowchart + ss << "flowchart TD\n"; + + // Output nodes + for (auto &node : nodes) { + ss << node << "\n\n"; + } + + // Output edges + for (auto &edge : edges) { + ss << edge << "\n"; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp index 251736dd4..09dcb0356 100644 --- a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp +++ b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp @@ -491,7 +491,7 @@ void TextTreeRenderer::SplitUpExtraInfo(const InsertionOrderPreservingMap max_lines) { diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 40ff794e6..5d4f745ea 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -31,6 +31,9 @@ namespace duckdb { +constexpr idx_t ArrayType::MAX_ARRAY_SIZE; +const idx_t UnionType::MAX_UNION_MEMBERS; + LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) { } @@ -159,6 +162,8 @@ PhysicalType LogicalType::GetInternalType() { return PhysicalType::UNKNOWN; case LogicalTypeId::AGGREGATE_STATE: return PhysicalType::VARCHAR; + case LogicalTypeId::GEOMETRY: + return PhysicalType::VARCHAR; default: throw InternalException("Invalid LogicalType %s", ToString()); } @@ -806,6 +811,7 @@ bool LogicalType::SupportsRegularUpdate() const { case LogicalTypeId::ARRAY: case LogicalTypeId::MAP: case LogicalTypeId::UNION: + case LogicalTypeId::VARIANT: return false; case LogicalTypeId::STRUCT: { auto &child_types = StructType::GetChildTypes(*this); @@ -1344,6 +1350,8 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 102; case LogicalTypeId::BIGNUM: return 103; + case LogicalTypeId::GEOMETRY: + return 104; // nested types case LogicalTypeId::STRUCT: return 125; @@ -2014,6 +2022,15 @@ LogicalType LogicalType::VARIANT() { return LogicalType(LogicalTypeId::VARIANT, std::move(info)); } +//===--------------------------------------------------------------------===// +// Spatial Types +//===--------------------------------------------------------------------===// + +LogicalType LogicalType::GEOMETRY() { + auto info = make_shared_ptr(); + return LogicalType(LogicalTypeId::GEOMETRY, std::move(info)); +} + //===--------------------------------------------------------------------===// // Logical Type //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/batched_data_collection.cpp b/src/duckdb/src/common/types/batched_data_collection.cpp index fd25dbc1a..6f38c098c 100644 --- a/src/duckdb/src/common/types/batched_data_collection.cpp +++ b/src/duckdb/src/common/types/batched_data_collection.cpp @@ -2,18 +2,47 @@ #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/printer.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p) { + ColumnDataAllocatorType allocator_type_p, + ColumnDataCollectionLifetime lifetime_p) + : context(context_p), types(std::move(types_p)), allocator_type(allocator_type_p), lifetime(lifetime_p) { +} + +BatchedDataCollection::BatchedDataCollection(ClientContext &context, vector types, + QueryResultMemoryType memory_type) + : BatchedDataCollection(context, std::move(types), + memory_type == 
QueryResultMemoryType::BUFFER_MANAGED + ? ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR + : ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + memory_type == QueryResultMemoryType::BUFFER_MANAGED + ? ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES + : ColumnDataCollectionLifetime::REGULAR) { } BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, batch_map_t batches, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p), data(std::move(batches)) { + ColumnDataAllocatorType allocator_type_p, + ColumnDataCollectionLifetime lifetime_p) + : context(context_p), types(std::move(types_p)), allocator_type(allocator_type_p), lifetime(lifetime_p), + data(std::move(batches)) { +} + +unique_ptr BatchedDataCollection::CreateCollection() const { + if (last_collection.collection) { + return make_uniq(*last_collection.collection); + } else if (allocator_type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { + auto &buffer_manager = lifetime == ColumnDataCollectionLifetime::REGULAR + ? BufferManager::GetBufferManager(context) + : BufferManager::GetBufferManager(*context.db); + return make_uniq(buffer_manager, types, lifetime); + } else { + D_ASSERT(allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + return make_uniq(Allocator::DefaultAllocator(), types); + } } void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { @@ -25,14 +54,7 @@ void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { } else { // new collection: check if there is already an entry D_ASSERT(data.find(batch_index) == data.end()); - unique_ptr new_collection; - if (last_collection.collection) { - new_collection = make_uniq(*last_collection.collection); - } else if (buffer_managed) { - new_collection = make_uniq(BufferManager::GetBufferManager(context), types); - } else { - new_collection = make_uniq(Allocator::DefaultAllocator(), types); - } + unique_ptr new_collection = CreateCollection(); last_collection.collection = new_collection.get(); last_collection.batch_index = batch_index; new_collection->InitializeAppend(last_collection.append_state); @@ -98,7 +120,7 @@ unique_ptr BatchedDataCollection::FetchCollection() { data.clear(); if (!result) { // empty result - return make_uniq(Allocator::DefaultAllocator(), types); + return CreateCollection(); } return result; } diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp index b4f3f4d74..b0fefb32e 100644 --- a/src/duckdb/src/common/types/column/column_data_allocator.cpp +++ b/src/duckdb/src/common/types/column/column_data_allocator.cpp @@ -2,6 +2,8 @@ #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/types/column/column_data_collection_segment.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/result_set_manager.hpp" #include "duckdb/storage/buffer/block_handle.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "duckdb/storage/buffer_manager.hpp" @@ -12,17 +14,24 @@ ColumnDataAllocator::ColumnDataAllocator(Allocator &allocator) : type(ColumnData alloc.allocator = &allocator; } -ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager) +ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager, ColumnDataCollectionLifetime lifetime) : type(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { alloc.buffer_manager = &buffer_manager; + if (lifetime == 
ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES) { + managed_result_set = ResultSetManager::Get(buffer_manager.GetDatabase()).Add(*this); + } } -ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type) +ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type, + ColumnDataCollectionLifetime lifetime) : type(allocator_type) { switch (type) { case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: case ColumnDataAllocatorType::HYBRID: alloc.buffer_manager = &BufferManager::GetBufferManager(context); + if (lifetime == ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES) { + managed_result_set = ResultSetManager::Get(context).Add(*this); + } break; case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: alloc.allocator = &Allocator::Get(context); @@ -38,6 +47,9 @@ ColumnDataAllocator::ColumnDataAllocator(ColumnDataAllocator &other) { case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: case ColumnDataAllocatorType::HYBRID: alloc.buffer_manager = other.alloc.buffer_manager; + if (other.managed_result_set.IsValid()) { + ResultSetManager::Get(alloc.buffer_manager->GetDatabase()).Add(*this); + } break; case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: alloc.allocator = other.alloc.allocator; @@ -51,8 +63,16 @@ ColumnDataAllocator::~ColumnDataAllocator() { if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { return; } + if (managed_result_set.IsValid()) { + D_ASSERT(type != ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); + auto db = managed_result_set.GetDatabase(); + if (db) { + ResultSetManager::Get(*db).Remove(*this); + } + return; + } for (auto &block : blocks) { - block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + block.GetHandle()->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); } blocks.clear(); } @@ -64,9 +84,9 @@ BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { // we only need to grab the lock when accessing the vector, because vector access is not thread-safe: // the vector can be resized by another thread while we try to access it lock_guard guard(lock); - handle = blocks[block_id].handle; + handle = blocks[block_id].GetHandle(); } else { - handle = blocks[block_id].handle; + handle = blocks[block_id].GetHandle(); } return alloc.buffer_manager->Pin(handle); } @@ -78,10 +98,10 @@ BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { data.size = 0; data.capacity = NumericCast(max_size); auto pin = alloc.buffer_manager->Allocate(MemoryTag::COLUMN_DATA, max_size, false); - data.handle = pin.GetBlockHandle(); + data.SetHandle(managed_result_set, pin.GetBlockHandle()); blocks.push_back(std::move(data)); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits - blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); + blocks.back().GetHandle()->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); } allocated_size += max_size; return pin; @@ -98,7 +118,6 @@ void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { BlockMetaData data; data.size = 0; data.capacity = NumericCast(allocation_amount); - data.handle = nullptr; blocks.push_back(std::move(data)); allocated_size += allocation_amount; } @@ -131,7 +150,8 @@ void ColumnDataAllocator::AllocateBuffer(idx_t size, uint32_t &block_id, uint32_ block_id = NumericCast(blocks.size() - 1); if (chunk_state && chunk_state->handles.find(block_id) == chunk_state->handles.end()) { // 
not guaranteed to be pinned already by this thread (if shared allocator) - chunk_state->handles[block_id] = alloc.buffer_manager->Pin(blocks[block_id].handle); + auto handle = blocks[block_id].GetHandle(); + chunk_state->handles[block_id] = alloc.buffer_manager->Pin(handle); } offset = block.size; block.size += size; @@ -235,7 +255,18 @@ void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector } void ColumnDataAllocator::SetDestroyBufferUponUnpin(uint32_t block_id) { - blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); + blocks[block_id].GetHandle()->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); +} + +shared_ptr ColumnDataAllocator::GetDatabase() const { + if (!managed_result_set.IsValid()) { + return nullptr; + } + auto db = managed_result_set.GetDatabase(); + if (!db) { + throw ConnectionException("Trying to access a query result after the database instance has been closed"); + } + return db; } Allocator &ColumnDataAllocator::GetAllocator() { @@ -282,6 +313,26 @@ void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, Chun } } +shared_ptr BlockMetaData::GetHandle() const { + if (handle) { + return handle; + } + auto res = weak_handle.lock(); + if (!res) { + throw ConnectionException("Trying to access a query result after the database instance has been closed"); + } + return res; +} + +void BlockMetaData::SetHandle(ManagedResultSet &managed_result_set, shared_ptr handle_p) { + if (managed_result_set.IsValid()) { + managed_result_set.GetHandles().emplace_back(handle_p); + weak_handle = handle_p; + } else { + handle = std::move(handle_p); + } +} + uint32_t BlockMetaData::Capacity() { D_ASSERT(size <= capacity); return capacity - size; diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp index b53e07d68..d9d69b151 100644 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -8,6 +8,7 @@ #include "duckdb/common/types/value_map.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/main/database.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { @@ -59,9 +60,10 @@ ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector(allocator_p); } -ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { +ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p, + ColumnDataCollectionLifetime lifetime) { Initialize(std::move(types_p)); - allocator = make_shared_ptr(buffer_manager); + allocator = make_shared_ptr(buffer_manager, lifetime); } ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { @@ -70,8 +72,8 @@ ColumnDataCollection::ColumnDataCollection(shared_ptr alloc } ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, - ColumnDataAllocatorType type) - : ColumnDataCollection(make_shared_ptr(context, type), std::move(types_p)) { + ColumnDataAllocatorType type, ColumnDataCollectionLifetime lifetime) + : ColumnDataCollection(make_shared_ptr(context, type, lifetime), std::move(types_p)) { D_ASSERT(!types.empty()); } @@ -146,16 +148,22 @@ idx_t ColumnDataRow::RowIndex() const { //===--------------------------------------------------------------------===// // ColumnDataRowCollection 
//===--------------------------------------------------------------------===// -ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection) { +ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection, + const ColumnDataScanProperties properties) { if (collection.Count() == 0) { return; } // read all the chunks ColumnDataScanState temp_scan_state; - collection.InitializeScan(temp_scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); + collection.InitializeScan(temp_scan_state, properties); while (true) { auto chunk = make_uniq(); - collection.InitializeScanChunk(*chunk); + // Use default allocator so the chunk is independently usable even after the DB allocator is destroyed + if (properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { + collection.InitializeScanChunk(Allocator::DefaultAllocator(), *chunk); + } else { + collection.InitializeScanChunk(*chunk); + } if (!collection.Scan(temp_scan_state, *chunk)) { break; } @@ -252,12 +260,13 @@ ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataColle : collection(collection_p) { } -ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) +ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p, + ColumnDataScanProperties properties) : collection(collection_p), scan_chunk(make_shared_ptr()), current_row(*scan_chunk, 0, 0) { if (!collection) { return; } - collection->InitializeScan(scan_state); + collection->InitializeScan(scan_state, properties); collection->InitializeScanChunk(*scan_chunk); collection->Scan(scan_state, *scan_chunk); } @@ -593,7 +602,6 @@ bool ColumnDataCopyCompressedStrings(ColumnDataMetaData &meta_data, const Vector template <> void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - const auto &allocator_type = meta_data.segment.allocator->GetType(); if (allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || allocator_type == ColumnDataAllocatorType::HYBRID) { @@ -733,7 +741,6 @@ void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVector template <> void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; auto &child_vector = ListVector::GetEntry(source); @@ -813,7 +820,6 @@ void ColumnDataCopyStruct(ColumnDataMetaData &meta_data, const UnifiedVectorForm void ColumnDataCopyArray(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; // copy the NULL values for the main array vector (the same as for a struct vector) @@ -842,7 +848,8 @@ void ColumnDataCopyArray(ColumnDataMetaData &meta_data, const UnifiedVectorForma child_vector.ToUnifiedFormat(copy_count * array_size, child_vector_data); // Broadcast and sync the validity of the array vector to the child vector - + // This requires creating a copy of the validity mask: we cannot modify the input validity + child_vector_data.validity = ValidityMask(child_vector_data.validity, child_vector_data.validity.Capacity()); if (source_data.validity.IsMaskSet()) { for (idx_t i = 0; i < copy_count; i++) { auto source_idx = source_data.sel->get_index(offset + i); @@ -1015,6 +1022,7 @@ void ColumnDataCollection::InitializeScan(ColumnDataScanState 
&state, ColumnData void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, vector column_ids, ColumnDataScanProperties properties) const { + state.db = allocator->GetDatabase(); state.chunk_index = 0; state.segment_index = 0; state.current_row_index = 0; @@ -1036,6 +1044,7 @@ void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, ve bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, DataChunk &result) const { + D_ASSERT(result.GetTypes() == types); result.Reset(); idx_t chunk_index; @@ -1052,7 +1061,11 @@ bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLo } void ColumnDataCollection::InitializeScanChunk(DataChunk &chunk) const { - chunk.Initialize(allocator->GetAllocator(), types); + InitializeScanChunk(allocator->GetAllocator(), chunk); +} + +void ColumnDataCollection::InitializeScanChunk(Allocator &allocator, DataChunk &chunk) const { + chunk.Initialize(allocator, types); } void ColumnDataCollection::InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const { @@ -1129,6 +1142,10 @@ void ColumnDataCollection::ScanAtIndex(ColumnDataParallelScanState &state, Colum } bool ColumnDataCollection::Scan(ColumnDataScanState &state, DataChunk &result) const { + for (idx_t i = 0; i < state.column_ids.size(); i++) { + D_ASSERT(result.GetTypes()[i] == types[state.column_ids[i]]); + } + result.Reset(); idx_t chunk_index; @@ -1213,6 +1230,7 @@ idx_t ColumnDataCollection::ChunkCount() const { } void ColumnDataCollection::FetchChunk(idx_t chunk_idx, DataChunk &result) const { + D_ASSERT(result.GetTypes() == types); D_ASSERT(chunk_idx < ChunkCount()); for (auto &segment : segments) { if (chunk_idx >= segment->ChunkCount()) { @@ -1354,6 +1372,11 @@ ColumnDataAllocatorType ColumnDataCollection::GetAllocatorType() const { return allocator->GetType(); } +BufferManager &ColumnDataCollection::GetBufferManager() const { + D_ASSERT(allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + return allocator->GetBufferManager(); +} + const vector> &ColumnDataCollection::GetSegments() const { return segments; } diff --git a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp index 1c1195569..9db48b9c4 100644 --- a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp @@ -246,7 +246,7 @@ idx_t ColumnDataCollectionSegment::ReadVector(ChunkManagementState &state, Vecto } } if (state.properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { - VectorOperations::Copy(result, result, vdata.count, 0, 0); + VectorOperations::Copy(result, result, vcount, 0, 0); } } return vcount; diff --git a/src/duckdb/src/common/types/conflict_manager.cpp b/src/duckdb/src/common/types/conflict_manager.cpp index 49d5d1186..9348fd5c0 100644 --- a/src/duckdb/src/common/types/conflict_manager.cpp +++ b/src/duckdb/src/common/types/conflict_manager.cpp @@ -87,7 +87,7 @@ optional_idx ConflictManager::GetFirstInvalidIndex(const idx_t count, const bool for (idx_t i = 0; i < count; i++) { if (negate && !validity.RowIsValid(i)) { return i; - } else if (validity.RowIsValid(i)) { + } else if (!negate && validity.RowIsValid(i)) { return i; } } diff --git a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp index 59e7faba7..6b216b185 100644 --- a/src/duckdb/src/common/types/data_chunk.cpp 
+++ b/src/duckdb/src/common/types/data_chunk.cpp @@ -254,7 +254,6 @@ string DataChunk::ToString() const { } void DataChunk::Serialize(Serializer &serializer, bool compressed_serialization) const { - // write the count auto row_count = size(); serializer.WriteProperty(100, "rows", NumericCast(row_count)); @@ -279,7 +278,6 @@ void DataChunk::Serialize(Serializer &serializer, bool compressed_serialization) } void DataChunk::Deserialize(Deserializer &deserializer) { - // read and set the row count auto row_count = deserializer.ReadProperty(100, "rows"); diff --git a/src/duckdb/src/common/types/decimal.cpp b/src/duckdb/src/common/types/decimal.cpp index 5ecb39a0a..8fa226455 100644 --- a/src/duckdb/src/common/types/decimal.cpp +++ b/src/duckdb/src/common/types/decimal.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/types/cast_helpers.hpp" namespace duckdb { +constexpr uint8_t Decimal::MAX_WIDTH_DECIMAL; template string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { diff --git a/src/duckdb/src/common/types/geometry.cpp b/src/duckdb/src/common/types/geometry.cpp new file mode 100644 index 000000000..d565d36f8 --- /dev/null +++ b/src/duckdb/src/common/types/geometry.cpp @@ -0,0 +1,1143 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/vector_operations/unary_executor.hpp" +#include "fast_float/fast_float.h" +#include "fmt/format.h" + +//---------------------------------------------------------------------------------------------------------------------- +// Internals +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { + +namespace { + +class BlobWriter { +public: + template + void Write(const T &value) { + auto ptr = reinterpret_cast(&value); + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + template + struct Reserved { + size_t offset; + T value; + }; + + template + Reserved Reserve() { + auto offset = buffer.size(); + buffer.resize(buffer.size() + sizeof(T)); + return {offset, T()}; + } + + template + void Write(const Reserved &reserved) { + if (reserved.offset + sizeof(T) > buffer.size()) { + throw InternalException("Write out of bounds in BinaryWriter"); + } + auto ptr = reinterpret_cast(&reserved.value); + // We've reserved 0 bytes, so we can safely memcpy + memcpy(buffer.data() + reserved.offset, ptr, sizeof(T)); + } + + void Write(const char *data, size_t size) { + D_ASSERT(data != nullptr); + buffer.insert(buffer.end(), data, data + size); + } + + const vector &GetBuffer() const { + return buffer; + } + + void Clear() { + buffer.clear(); + } + +private: + vector buffer; +}; + +class FixedSizeBlobWriter { +public: + FixedSizeBlobWriter(char *data, uint32_t size) : beg(data), pos(data), end(data + size) { + } + + template + void Write(const T &value) { + if (pos + sizeof(T) > end) { + throw InvalidInputException("Writing beyond end of binary data at position %zu", pos - beg); + } + memcpy(pos, &value, sizeof(T)); + pos += sizeof(T); + } + + void Write(const char *data, size_t size) { + if (pos + size > end) { + throw InvalidInputException("Writing beyond end of binary data at position %zu", pos - beg); + } + memcpy(pos, data, size); + pos += size; + } + + size_t GetPosition() const { + return static_cast(pos - beg); + } + +private: + const char *beg; + char *pos; + const char *end; +}; + +class BlobReader { +public: + BlobReader(const char *data, uint32_t size) 
: beg(data), pos(data), end(data + size) { + } + + template + T Read(const bool le) { + if (le) { + return Read(); + } else { + return Read(); + } + } + + template + T Read() { + if (pos + sizeof(T) > end) { + throw InvalidInputException("Unexpected end of binary data at position %zu", pos - beg); + } + T value; + if (LE) { + memcpy(&value, pos, sizeof(T)); + pos += sizeof(T); + } else { + char temp[sizeof(T)]; + for (size_t i = 0; i < sizeof(T); ++i) { + temp[i] = pos[sizeof(T) - 1 - i]; + } + memcpy(&value, temp, sizeof(T)); + pos += sizeof(T); + } + return value; + } + + void Skip(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Skipping beyond end of binary data at position %zu", pos - beg); + } + pos += size; + } + + const char *Reserve(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Reserving beyond end of binary data at position %zu", pos - beg); + } + auto current_pos = pos; + pos += size; + return current_pos; + } + + size_t GetPosition() const { + return static_cast(pos - beg); + } + + const char *GetDataPtr() const { + return pos; + } + + bool IsAtEnd() const { + return pos >= end; + } + + void Reset() { + pos = beg; + } + +private: + const char *beg; + const char *pos; + const char *end; +}; + +class TextWriter { +public: + void Write(const char *str) { + buffer.insert(buffer.end(), str, str + strlen(str)); + } + void Write(char c) { + buffer.push_back(c); + } + void Write(double value) { + duckdb_fmt::format_to(std::back_inserter(buffer), "{}", value); + // Remove trailing zero + if (buffer.back() == '0') { + buffer.pop_back(); + if (buffer.back() == '.') { + buffer.pop_back(); + } + } + } + const vector &GetBuffer() const { + return buffer; + } + +private: + vector buffer; +}; + +class TextReader { +public: + TextReader(const char *text, const uint32_t size) : beg(text), pos(text), end(text + size) { + } + + bool TryMatch(const char *str) { + auto ptr = pos; + while (*str && pos < end && tolower(*pos) == tolower(*str)) { + pos++; + str++; + } + if (*str == '\0') { + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } + pos = ptr; // reset position + return false; // not matched + } + + bool TryMatch(char c) { + if (pos < end && tolower(*pos) == tolower(c)) { + pos++; + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } + return false; // not matched + } + + void Match(const char *str) { + if (!TryMatch(str)) { + throw InvalidInputException("Expected '%s' but got '%c' at position %zu", str, *pos, pos - beg); + } + } + + void Match(char c) { + if (!TryMatch(c)) { + throw InvalidInputException("Expected '%c' but got '%c' at position %zu", c, *pos, pos - beg); + } + } + + double MatchNumber() { + // Now use fast_float to parse the number + double num; + const auto res = duckdb_fast_float::from_chars(pos, end, num); + if (res.ec != std::errc()) { + throw InvalidInputException("Expected number at position %zu", pos - beg); + } + + pos = res.ptr; // update position to the end of the parsed number + + SkipWhitespace(); // remove trailing whitespace + return num; // return the parsed number + } + + idx_t GetPosition() const { + return static_cast(pos - beg); + } + + void Reset() { + pos = beg; + } + +private: + void SkipWhitespace() { + while (pos < end && isspace(*pos)) { + pos++; + } + } + + const char *beg; + const char *pos; + const char *end; +}; + +void FromStringRecursive(TextReader &reader, BlobWriter &writer, uint32_t depth, bool parent_has_z, bool parent_has_m) { + if (depth 
== Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry string exceeds maximum recursion depth of %d", + Geometry::MAX_RECURSION_DEPTH); + } + + GeometryType type; + + if (reader.TryMatch("point")) { + type = GeometryType::POINT; + } else if (reader.TryMatch("linestring")) { + type = GeometryType::LINESTRING; + } else if (reader.TryMatch("polygon")) { + type = GeometryType::POLYGON; + } else if (reader.TryMatch("multipoint")) { + type = GeometryType::MULTIPOINT; + } else if (reader.TryMatch("multilinestring")) { + type = GeometryType::MULTILINESTRING; + } else if (reader.TryMatch("multipolygon")) { + type = GeometryType::MULTIPOLYGON; + } else if (reader.TryMatch("geometrycollection")) { + type = GeometryType::GEOMETRYCOLLECTION; + } else { + throw InvalidInputException("Unknown geometry type at position %zu", reader.GetPosition()); + } + + const auto has_z = reader.TryMatch("z"); + const auto has_m = reader.TryMatch("m"); + + const auto is_empty = reader.TryMatch("empty"); + + if ((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + // How many dimensions does this geometry have? + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + + // WKB type + const auto meta = static_cast(type) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + // Write the geometry type and vertex type + writer.Write(1); // LE Byte Order + writer.Write(meta); + + switch (type) { + case GeometryType::POINT: { + if (is_empty) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + reader.Match('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + reader.Match(')'); + } + } break; + case GeometryType::LINESTRING: { + if (is_empty) { + writer.Write(0); // No vertices in empty linestring + break; + } + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + } break; + case GeometryType::POLYGON: { + if (is_empty) { + writer.Write(0); + break; // No rings in empty polygon + } + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + } break; + case GeometryType::MULTIPOINT: { + if (is_empty) { + writer.Write(0); // No points in empty multipoint + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + bool has_paren = reader.TryMatch('('); + + const auto part_meta = static_cast(GeometryType::POINT) + (has_z ? 1000 : 0) + (has_m ? 
2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + if (reader.TryMatch("EMPTY")) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + } + if (has_paren) { + reader.Match(')'); // Match the closing parenthesis if it was opened + } + part_count.value++; + } while (reader.TryMatch(',')); + writer.Write(part_count); + } break; + case GeometryType::MULTILINESTRING: { + if (is_empty) { + writer.Write(0); + return; // No linestrings in empty multilinestring + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + const auto part_meta = + static_cast(GeometryType::LINESTRING) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::MULTIPOLYGON: { + if (is_empty) { + writer.Write(0); // No polygons in empty multipolygon + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + const auto part_meta = + static_cast(GeometryType::POLYGON) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + if (is_empty) { + writer.Write(0); // No geometries in empty geometry collection + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + // Recursively parse the geometry inside the collection + FromStringRecursive(reader, writer, depth + 1, has_z, has_m); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + default: + throw InvalidInputException("Unknown geometry type %d at position %zu", static_cast(type), + reader.GetPosition()); + } +} + +void ToStringRecursive(BlobReader &reader, TextWriter &writer, idx_t depth, bool parent_has_z, bool parent_has_m) { + if (depth == Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry exceeds maximum recursion depth of %d", Geometry::MAX_RECURSION_DEPTH); + } + + // Read the byte order (should always be 1 for little-endian) + auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type = static_cast((meta & 0x0000FFFF) % 1000); + const auto flag = (meta & 0x0000FFFF) / 1000; + const auto has_z = (flag & 0x01) != 0; + const auto has_m = (flag & 0x02) != 0; + + if 
((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + const auto flag_str = has_z ? (has_m ? " ZM " : " Z ") : (has_m ? " M " : " "); + + switch (type) { + case GeometryType::POINT: { + writer.Write("POINT"); + writer.Write(flag_str); + + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + writer.Write(')'); + } break; + case GeometryType::LINESTRING: { + writer.Write("LINESTRING"); + ; + writer.Write(flag_str); + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } break; + case GeometryType::POLYGON: { + writer.Write("POLYGON"); + writer.Write(flag_str); + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOINT: { + writer.Write("MULTIPOINT"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::POINT) { + throw InvalidInputException("Expected POINT in MULTIPOINT but got %d", static_cast(part_type)); + } + + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOINT, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + continue; + } + // 
writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + // writer.Write(')'); + } + writer.Write(')'); + + } break; + case GeometryType::MULTILINESTRING: { + writer.Write("MULTILINESTRING"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::LINESTRING) { + throw InvalidInputException("Expected LINESTRING in MULTILINESTRING but got %d", + static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTILINESTRING, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOLYGON: { + writer.Write("MULTIPOLYGON"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast((part_meta & 0x0000FFFF) % 1000); + const auto part_flag = (part_meta & 0x0000FFFF) / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + if (part_type != GeometryType::POLYGON) { + throw InvalidInputException("Expected POLYGON in MULTIPOLYGON but got %d", static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOLYGON, starting at position %zu", + reader.GetPosition()); + } + + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); 
+ writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + writer.Write("GEOMETRYCOLLECTION"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + // Recursively parse the geometry inside the collection + ToStringRecursive(reader, writer, depth + 1, has_z, has_m); + } + writer.Write(')'); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(type)); + } +} + +struct WKBAnalysis { + uint32_t size = 0; + bool any_be = false; + bool any_z = false; + bool any_m = false; + bool any_unknown = false; + bool any_ewkb = false; +}; + +WKBAnalysis AnalyzeWKB(BlobReader &reader) { + WKBAnalysis result; + + while (!reader.IsAtEnd()) { + const auto le = reader.Read() == 1; + + const auto meta = reader.Read(le); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + // Extended WKB detection + const auto has_extz = (meta & 0x80000000) != 0; + const auto has_extm = (meta & 0x40000000) != 0; + const auto has_srid = (meta & 0x20000000) != 0; + + const auto has_z = ((flag_id & 0x01) != 0) || has_extz; + const auto has_m = ((flag_id & 0x02) != 0) || has_extm; + + if (has_srid) { + result.any_ewkb = true; + reader.Skip(sizeof(uint32_t)); // Skip SRID + // Do not include SRID in the size + } + + if (has_extz || has_extm || has_srid) { + // EWKB flags are set + result.any_ewkb = true; + } + + const auto v_size = (2 + (has_z ? 1 : 0) + (has_m ? 1 : 0)) * sizeof(double); + + result.any_z |= has_z; + result.any_m |= has_m; + result.any_be |= !le; + + result.size += sizeof(uint8_t) + sizeof(uint32_t); // Byte order + type/meta + + switch (type_id) { + case 1: { // POINT + reader.Skip(v_size); + result.size += v_size; + } break; + case 2: { // LINESTRING + const auto vert_count = reader.Read(le); + reader.Skip(vert_count * v_size); + result.size += sizeof(uint32_t) + vert_count * v_size; + } break; + case 3: { // POLYGON + const auto ring_count = reader.Read(le); + result.size += sizeof(uint32_t); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(le); + reader.Skip(vert_count * v_size); + result.size += sizeof(uint32_t) + vert_count * v_size; + } + } break; + case 4: // MULTIPOINT + case 5: // MULTILINESTRING + case 6: // MULTIPOLYGON + case 7: { // GEOMETRYCOLLECTION + reader.Skip(sizeof(uint32_t)); + result.size += sizeof(uint32_t); // part count + } break; + default: { + result.any_unknown = true; + return result; + } + } + } + return result; +} + +void ConvertWKB(BlobReader &reader, FixedSizeBlobWriter &writer) { + while (!reader.IsAtEnd()) { + const auto le = reader.Read() == 1; + const auto meta = reader.Read(le); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + // Extended WKB detection + const auto has_extz = (meta & 0x80000000) != 0; + const auto has_extm = (meta & 0x40000000) != 0; + const auto has_srid = (meta & 0x20000000) != 0; + + const auto has_z = ((flag_id & 0x01) != 0) || has_extz; + const auto has_m = ((flag_id & 0x02) != 0) || has_extm; + + if (has_srid) { + reader.Skip(sizeof(uint32_t)); // Skip SRID + } + + const auto v_width = static_cast((2 + (has_z ? 1 : 0) + (has_m ? 
1 : 0))); + + writer.Write(1); // Always write LE + writer.Write(type_id + (1000 * has_z) + (2000 * has_m)); // Write meta + + switch (type_id) { + case 1: { // POINT + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } break; + case 2: { // LINESTRING + const auto vert_count = reader.Read(le); + writer.Write(vert_count); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } + } break; + case 3: { // POLYGON + const auto ring_count = reader.Read(le); + writer.Write(ring_count); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(le); + writer.Write(vert_count); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + for (uint32_t d_idx = 0; d_idx < v_width; d_idx++) { + auto value = reader.Read(le); + writer.Write(value); + } + } + } + } break; + case 4: // MULTIPOINT + case 5: // MULTILINESTRING + case 6: // MULTIPOLYGON + case 7: { // GEOMETRYCOLLECTION + const auto part_count = reader.Read(le); + writer.Write(part_count); + } break; + default: + D_ASSERT(false); + break; + } + } +} + +} // namespace + +} // namespace duckdb + +//---------------------------------------------------------------------------------------------------------------------- +// Public interface +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { + +constexpr const idx_t Geometry::MAX_RECURSION_DEPTH; + +bool Geometry::FromBinary(const string_t &wkb, string_t &result, Vector &result_vector, bool strict) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + const auto analysis = AnalyzeWKB(reader); + if (analysis.any_unknown) { + if (strict) { + throw InvalidInputException("Unsupported geometry type in WKB"); + } + return false; + } + + if (analysis.any_be || analysis.any_ewkb) { + reader.Reset(); + // Make a new WKB with all LE + auto blob = StringVector::EmptyString(result_vector, analysis.size); + FixedSizeBlobWriter writer(blob.GetDataWriteable(), static_cast(blob.GetSize())); + ConvertWKB(reader, writer); + blob.Finalize(); + result = blob; + return true; + } + + // Copy the WKB as-is + result = StringVector::AddStringOrBlob(result_vector, wkb.GetData(), wkb.GetSize()); + return true; +} + +bool Geometry::FromBinary(Vector &source, Vector &result, idx_t count, bool strict) { + if (strict) { + UnaryExecutor::Execute(source, result, count, [&](const string_t &wkb) { + string_t geom; + FromBinary(wkb, geom, result, true); + return geom; + }); + return true; + } + + auto all_ok = true; + UnaryExecutor::ExecuteWithNulls(source, result, count, + [&](const string_t &wkb, ValidityMask &mask, idx_t idx) { + string_t geom; + if (!FromBinary(wkb, geom, result, false)) { + all_ok = false; + mask.SetInvalid(idx); + return string_t(); + } + return geom; + }); + return all_ok; +} + +void Geometry::ToBinary(Vector &source, Vector &result, idx_t count) { + // We are currently using WKB internally, so just copy as-is! 
+ result.Reinterpret(source); +} + +bool Geometry::FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict) { + TextReader reader(wkt_text.GetData(), static_cast(wkt_text.GetSize())); + BlobWriter writer; + + FromStringRecursive(reader, writer, 0, false, false); + + const auto &buffer = writer.GetBuffer(); + result = StringVector::AddStringOrBlob(result_vector, buffer.data(), buffer.size()); + return true; +} + +string_t Geometry::ToString(Vector &result, const string_t &geom) { + BlobReader reader(geom.GetData(), static_cast(geom.GetSize())); + TextWriter writer; + + ToStringRecursive(reader, writer, 0, false, false); + + // Convert the buffer to string_t + const auto &buffer = writer.GetBuffer(); + return StringVector::AddString(result, buffer.data(), buffer.size()); +} + +pair Geometry::GetType(const string_t &wkb) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + // Read the byte order (should always be 1 for little-endian) + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag %d in WKB", flag_id); + } + + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + return {geom_type, vert_type}; +} + +template +static uint32_t ParseVerticesInternal(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, bool check_nan) { + uint32_t count = 0; + + // Issue a single .Reserve() for all vertices, to minimize bounds checking overhead + const auto ptr = const_data_ptr_cast(reader.Reserve(vert_count * sizeof(VERTEX_TYPE))); + + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + VERTEX_TYPE vertex = Load(ptr + vert_idx * sizeof(VERTEX_TYPE)); + if (check_nan && vertex.AllNan()) { + continue; + } + + extent.Extend(vertex); + count++; + } + return count; +} + +static uint32_t ParseVertices(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, VertexType type, + bool check_nan) { + switch (type) { + case VertexType::XY: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZ: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + default: + throw InvalidInputException("Unsupported vertex type %d in WKB", static_cast(type)); + } +} + +uint32_t Geometry::GetExtent(const string_t &wkb, GeometryExtent &extent) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + uint32_t vertex_count = 0; + + while (!reader.IsAtEnd()) { + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + const auto meta = reader.Read(); + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag 
%d in WKB", flag_id); + } + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + switch (geom_type) { + case GeometryType::POINT: { + vertex_count += ParseVertices(reader, extent, 1, vert_type, true); + } break; + case GeometryType::LINESTRING: { + const auto vert_count = reader.Read(); + vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } break; + case GeometryType::POLYGON: { + const auto ring_count = reader.Read(); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(); + vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } + } break; + case GeometryType::MULTIPOINT: + case GeometryType::MULTILINESTRING: + case GeometryType::MULTIPOLYGON: + case GeometryType::GEOMETRYCOLLECTION: { + // Skip count. We don't need it for extent calculation. + reader.Skip(sizeof(uint32_t)); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(geom_type)); + } + } + return vertex_count; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp index 8145cf07f..88c00dbc2 100644 --- a/src/duckdb/src/common/types/list_segment.cpp +++ b/src/duckdb/src/common/types/list_segment.cpp @@ -239,7 +239,6 @@ static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAlloc template static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -258,7 +257,6 @@ static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAlloc static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -297,7 +295,6 @@ static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, Are static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -331,7 +328,6 @@ static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaA static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); // write null validity @@ -376,7 +372,6 @@ static void WriteDataToArraySegment(const ListSegmentFunctions &functions, Arena void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { - auto &write_data_to_segment = *this; auto segment = GetSegment(write_data_to_segment, allocator, linked_list); write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); @@ -391,7 +386,6 @@ void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &link template static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector 
&result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -462,7 +456,6 @@ static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListS static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -503,7 +496,6 @@ static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -528,7 +520,6 @@ static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, con static void ReadDataFromArraySegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); // set NULLs @@ -570,7 +561,6 @@ void SegmentPrimitiveFunction(ListSegmentFunctions &functions) { } void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type) { - if (type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } diff --git a/src/duckdb/src/common/types/row/block_iterator.cpp b/src/duckdb/src/common/types/row/block_iterator.cpp deleted file mode 100644 index bebba60e1..000000000 --- a/src/duckdb/src/common/types/row/block_iterator.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "duckdb/common/types/row/block_iterator.hpp" - -namespace duckdb { - -BlockIteratorStateType GetBlockIteratorStateType(const bool &external) { - return external ? BlockIteratorStateType::EXTERNAL : BlockIteratorStateType::IN_MEMORY; -} - -InMemoryBlockIteratorState::InMemoryBlockIteratorState(const TupleDataCollection &key_data) - : block_ptrs(ConvertBlockPointers(key_data.GetRowBlockPointers())), fast_mod(key_data.TuplesPerBlock()), - tuple_count(key_data.Count()) { -} - -unsafe_vector InMemoryBlockIteratorState::ConvertBlockPointers(const vector &block_ptrs) { - unsafe_vector converted_block_ptrs; - converted_block_ptrs.reserve(block_ptrs.size()); - for (const auto &block_ptr : block_ptrs) { - converted_block_ptrs.emplace_back(block_ptr); - } - return converted_block_ptrs; -} - -ExternalBlockIteratorState::ExternalBlockIteratorState(TupleDataCollection &key_data_p, - optional_ptr payload_data_p) - : tuple_count(key_data_p.Count()), current_chunk_idx(DConstants::INVALID_INDEX), key_data(key_data_p), - key_ptrs(FlatVector::GetData(key_scan_state.chunk_state.row_locations)), payload_data(payload_data_p), - keep_pinned(false), pin_payload(false) { - key_data.InitializeScan(key_scan_state); - if (payload_data) { - payload_data->InitializeScan(payload_scan_state); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp index eff1186b0..69da0d322 100644 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -2,20 +2,20 @@ #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/types/row/tuple_data_iterator.hpp" -#include "duckdb/storage/buffer_manager.hpp" #include "duckdb/common/printer.hpp" +#include "duckdb/storage/buffer_manager.hpp" namespace duckdb { PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, 
BufferManager &buffer_manager_p, - shared_ptr &layout_ptr_p) - : type(type_p), buffer_manager(buffer_manager_p), layout_ptr(layout_ptr_p), layout(*layout_ptr), count(0), - data_size(0) { + shared_ptr &layout_ptr_p, MemoryTag tag_p) + : type(type_p), buffer_manager(buffer_manager_p), + stl_allocator(make_shared_ptr(buffer_manager.GetBufferAllocator())), layout_ptr(layout_ptr_p), + layout(*layout_ptr), tag(tag_p), count(0), data_size(0) { } -PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) - : type(other.type), buffer_manager(other.buffer_manager), layout_ptr(other.layout_ptr), layout(*layout_ptr), - count(0), data_size(0) { +PartitionedTupleData::PartitionedTupleData(PartitionedTupleData &other) + : PartitionedTupleData(other.type, other.buffer_manager, other.layout_ptr, other.tag) { } PartitionedTupleData::~PartitionedTupleData() { @@ -318,7 +318,7 @@ unsafe_vector> &PartitionedTupleData::GetPartiti unique_ptr PartitionedTupleData::GetUnpartitioned() { auto data_collection = std::move(partitions[0]); - partitions[0] = make_uniq(buffer_manager, layout_ptr); + partitions[0] = make_uniq(buffer_manager, layout_ptr, tag); for (idx_t i = 1; i < partitions.size(); i++) { data_collection->Combine(*partitions[i]); diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp deleted file mode 100644 index b178b7fb5..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection.hpp" - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); -} - -idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the 
keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t *offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp deleted file mode 100644 index 9b3a4be06..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, 
- const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, - NumericCast(heap_offset)); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. 
- RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. 
- swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + NumericCast(block_idx); - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); -} - -void RowDataCollectionScanner::SwizzleBlock(idx_t 
block_idx) { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - auto &data_block = rows.blocks[block_idx]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); - } -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. 
- for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_layout.cpp b/src/duckdb/src/common/types/row/row_layout.cpp deleted file mode 100644 index 3add8e425..000000000 --- a/src/duckdb/src/common/types/row/row_layout.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/common/types/row/row_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. - row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. - for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). 
- row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp index 7c5fcd32b..50e04356c 100644 --- a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp @@ -12,8 +12,9 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; -TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p) : capacity(capacity_p), size(0) { - auto buffer_handle = buffer_manager.Allocate(MemoryTag::HASH_TABLE, capacity, false); +TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, MemoryTag tag, idx_t capacity_p) + : capacity(capacity_p), size(0) { + auto buffer_handle = buffer_manager.Allocate(tag, capacity, false); handle = buffer_handle.GetBlockHandle(); } @@ -30,12 +31,14 @@ TupleDataBlock &TupleDataBlock::operator=(TupleDataBlock &&other) noexcept { return *this; } -TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, shared_ptr &layout_ptr_p) - : buffer_manager(buffer_manager), layout_ptr(layout_ptr_p), layout(*layout_ptr) { +TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, shared_ptr layout_ptr_p, + MemoryTag tag_p, shared_ptr stl_allocator_p) + : stl_allocator(std::move(stl_allocator_p)), buffer_manager(buffer_manager), layout_ptr(std::move(layout_ptr_p)), + layout(*layout_ptr), tag(tag_p), row_blocks(*stl_allocator), heap_blocks(*stl_allocator) { } TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) - : buffer_manager(allocator.buffer_manager), layout_ptr(allocator.layout_ptr), layout(*layout_ptr) { + : TupleDataAllocator(allocator.buffer_manager, allocator.layout_ptr, allocator.tag, allocator.stl_allocator) { } void TupleDataAllocator::SetDestroyBufferUponUnpin() { @@ -82,6 +85,10 @@ Allocator &TupleDataAllocator::GetAllocator() { return buffer_manager.GetBufferAllocator(); } +ArenaAllocator &TupleDataAllocator::GetStlAllocator() { + return *stl_allocator; +} + shared_ptr TupleDataAllocator::GetLayoutPtr() const { return layout_ptr; } @@ -116,12 +123,12 @@ bool TupleDataAllocator::BuildFastPath(TupleDataSegment &segment, TupleDataPinSt return false; } - auto &chunk = chunks.back(); + auto &chunk = *chunks.back(); if (chunk.count + append_count > STANDARD_VECTOR_SIZE) { return false; } - auto &part = segment.chunk_parts[chunk.part_ids.End() - 1]; + auto &part = *segment.chunk_parts[chunk.part_ids.End() - 1]; auto &row_block = row_blocks[part.row_block_index]; const auto row_width = layout.GetRowWidth(); @@ -152,23 +159,23 @@ void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin D_ASSERT(this == segment.allocator.get()); auto &chunks = segment.chunks; if (!chunks.empty()) { - ReleaseOrStoreHandles(pin_state, segment, chunks.back(), true); + ReleaseOrStoreHandles(pin_state, segment, *chunks.back(), true); } if (!BuildFastPath(segment, pin_state, chunk_state, append_offset, append_count)) { // Build the chunk parts for the incoming data - chunk_part_indices.clear(); + chunk_state.chunk_part_indices.clear(); idx_t offset = 0; while (offset != append_count) { - if (chunks.empty() || chunks.back().count == STANDARD_VECTOR_SIZE) { - chunks.emplace_back(); + if (chunks.empty() || chunks.back()->count == STANDARD_VECTOR_SIZE) { + 
chunks.push_back(stl_allocator->MakeUnsafePtr(*stl_allocator->Make())); } - auto &chunk = chunks.back(); + auto &chunk = *chunks.back(); // Build the next part auto next = MinValue(append_count - offset, STANDARD_VECTOR_SIZE - chunk.count); - auto &chunk_part = - chunk.AddPart(segment, BuildChunkPart(pin_state, chunk_state, append_offset + offset, next, chunk)); + auto &chunk_part = chunk.AddPart( + segment, BuildChunkPart(segment, pin_state, chunk_state, append_offset + offset, next, chunk)); next = chunk_part.count; segment.count += next; @@ -190,34 +197,37 @@ void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin } offset += next; - chunk_part_indices.emplace_back(chunks.size() - 1, chunk.part_ids.End() - 1); + chunk_state.chunk_part_indices.emplace_back(chunks.size() - 1, chunk.part_ids.End() - 1); } // Now initialize the pointers to write the data to - chunk_parts.clear(); - for (const auto &indices : chunk_part_indices) { - chunk_parts.emplace_back(segment.chunk_parts[indices.second]); + chunk_state.chunk_parts.clear(); + for (const auto &indices : chunk_state.chunk_part_indices) { + chunk_state.chunk_parts.emplace_back(*segment.chunk_parts[indices.second]); } - InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, chunk_parts); + InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, + chunk_state.chunk_parts); // To reduce metadata, we try to merge chunk parts where possible // Due to the way chunk parts are constructed, only the last part of the first chunk is eligible for merging - segment.chunks[chunk_part_indices[0].first].MergeLastChunkPart(segment); + segment.chunks[chunk_state.chunk_part_indices[0].first]->MergeLastChunkPart(segment); } segment.Verify(); } -TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, - TupleDataChunk &chunk) { +unsafe_arena_ptr +TupleDataAllocator::BuildChunkPart(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, const idx_t append_offset, + const idx_t append_count, TupleDataChunk &chunk) { D_ASSERT(append_count != 0); - TupleDataChunkPart result(*chunk.lock); + auto result_ptr = stl_allocator->MakeUnsafePtr(chunk.lock.get()); + auto &result = *result_ptr; const auto block_size = buffer_manager.GetBlockSize(); // Allocate row block (if needed) if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { - row_blocks.emplace_back(buffer_manager, block_size); + CreateRowBlock(segment); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits row_blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); } @@ -272,7 +282,7 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta // Allocate heap block (if needed) if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { const auto size = MaxValue(block_size, heap_sizes[append_offset]); - heap_blocks.emplace_back(buffer_manager, size); + CreateHeapBlock(segment, size); if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits heap_blocks.back().handle->SetEvictionQueueIndex( RadixPartitioning::RadixBits(partition_index.GetIndex())); @@ -293,14 +303,15 @@ TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_sta // 
Mark this portion of the row block as filled row_block.size += result.count * layout.GetRowWidth(); - return result; + return result_ptr; } void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, - TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap) { + TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap, + optional_ptr sort_key_payload_state) { D_ASSERT(this == segment.allocator.get()); D_ASSERT(chunk_idx < segment.ChunkCount()); - auto &chunk = segment.chunks[chunk_idx]; + auto &chunk = *segment.chunks[chunk_idx]; // Release or store any handles that are no longer required: // We can't release the heap here if the current chunk's heap_block_ids is empty, because if we are iterating with @@ -308,12 +319,15 @@ void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDa // when chunk 0 needs heap block 0, chunk 1 does not need any heap blocks, and chunk 2 needs heap block 0 again ReleaseOrStoreHandles(pin_state, segment, chunk, !chunk.heap_block_ids.Empty()); - chunk_state.parts.clear(); + chunk_state.chunk_parts.clear(); for (auto part_id = chunk.part_ids.Start(); part_id < chunk.part_ids.End(); part_id++) { - chunk_state.parts.emplace_back(segment.chunk_parts[part_id]); + chunk_state.chunk_parts.emplace_back(*segment.chunk_parts[part_id]); } - InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, chunk_state.parts); + InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, chunk_state.chunk_parts, + sort_key_payload_state); + + chunk_state.chunk_lock = &chunk.lock.get(); } static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t heap_sizes[], const idx_t offset, @@ -335,13 +349,56 @@ static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t h #endif } +template +void TemplatedSortKeySetPayload(const data_ptr_t row_locations[], const idx_t offset, const idx_t count, + TupleDataChunkState &sort_key_chunk_state) { + using SORT_KEY = SortKey; + const auto sort_keys = FlatVector::GetData(sort_key_chunk_state.row_locations); + + lock_guard guard(*sort_key_chunk_state.chunk_lock); + if (sort_keys[offset]->GetPayload() == row_locations[offset]) { + return; // Still the same + } + + // Changed: set new pointers + for (idx_t i = offset; i < offset + count; i++) { + sort_keys[i]->SetPayload(row_locations[i]); + } +} + +void SortKeySetPayload(const data_ptr_t row_locations[], const idx_t offset, const idx_t count, + const SortKeyPayloadState &sort_key_payload_state) { + switch (sort_key_payload_state.sort_key_type) { + case SortKeyType::PAYLOAD_FIXED_16: + TemplatedSortKeySetPayload(row_locations, offset, count, + sort_key_payload_state.sort_key_chunk_state); + break; + case SortKeyType::PAYLOAD_FIXED_24: + TemplatedSortKeySetPayload(row_locations, offset, count, + sort_key_payload_state.sort_key_chunk_state); + break; + case SortKeyType::PAYLOAD_FIXED_32: + TemplatedSortKeySetPayload(row_locations, offset, count, + sort_key_payload_state.sort_key_chunk_state); + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + TemplatedSortKeySetPayload(row_locations, offset, count, + sort_key_payload_state.sort_key_chunk_state); + break; + default: + throw NotImplementedException("SortKeySetPayload for %s", + EnumUtil::ToString(sort_key_payload_state.sort_key_type)); + } +} + void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t offset, bool 
recompute, bool init_heap_pointers, bool init_heap_sizes, - unsafe_vector> &parts) { - auto row_locations = FlatVector::GetData(chunk_state.row_locations); - auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); + unsafe_vector> &parts, + optional_ptr sort_key_payload_state) { + const auto row_locations = FlatVector::GetData(chunk_state.row_locations); + const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); + const auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); for (auto &part_ref : parts) { auto &part = part_ref.get(); @@ -354,6 +411,12 @@ void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_sta row_locations[offset + i] = base_row_ptr + i * row_width; } + if (sort_key_payload_state) { + D_ASSERT(!layout.IsSortKeyLayout()); + lock_guard guard(part.lock); + SortKeySetPayload(row_locations, offset, next, *sort_key_payload_state); + } + if (layout.AllConstant()) { // Can't have a heap offset += next; continue; @@ -670,14 +733,15 @@ void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, Tup } void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment) { - static TupleDataChunk DUMMY_CHUNK; + mutex dummy_chunk_mutex; + static TupleDataChunk DUMMY_CHUNK(dummy_chunk_mutex); ReleaseOrStoreHandles(pin_state, segment, DUMMY_CHUNK, true); } void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, - unsafe_vector &pinned_handles, + unsafe_arena_vector &pinned_handles, buffer_handle_map_t &handles, const ContinuousIdSet &block_ids, - unsafe_vector &blocks, + unsafe_arena_vector &blocks, TupleDataPinProperties properties) { bool found_handle; do { @@ -691,10 +755,7 @@ void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment switch (properties) { case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: { lock_guard guard(segment.pinned_handles_lock); - const auto block_count = block_id + 1; - if (block_count > pinned_handles.size()) { - pinned_handles.resize(block_count); - } + D_ASSERT(blocks.size() == pinned_handles.size()); pinned_handles[block_id] = std::move(it->second); break; } @@ -718,6 +779,16 @@ void TupleDataAllocator::ReleaseOrStoreHandlesInternal(TupleDataSegment &segment } while (found_handle); } +void TupleDataAllocator::CreateRowBlock(TupleDataSegment &segment) { + row_blocks.emplace_back(buffer_manager, tag, buffer_manager.GetBlockSize()); + segment.pinned_row_handles.resize(row_blocks.size()); +} + +void TupleDataAllocator::CreateHeapBlock(TupleDataSegment &segment, idx_t size) { + heap_blocks.emplace_back(buffer_manager, tag, size); + segment.pinned_heap_handles.resize(heap_blocks.size()); +} + BufferHandle &TupleDataAllocator::PinRowBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { const auto &row_block_index = part.row_block_index; auto it = pin_state.row_handles.find(row_block_index); diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index ffd4a2b4c..ddb41d884 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -12,12 +12,22 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; -TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr_p) - : layout_ptr(std::move(layout_ptr_p)), 
layout(*layout_ptr), - allocator(make_shared_ptr(buffer_manager, layout_ptr)) { +TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr_p, + MemoryTag tag_p, shared_ptr stl_allocator_p) + : stl_allocator(stl_allocator_p ? std::move(stl_allocator_p) + : make_shared_ptr(buffer_manager.GetBufferAllocator())), + layout_ptr(std::move(layout_ptr_p)), layout(*layout_ptr), tag(tag_p), + allocator(make_shared_ptr(buffer_manager, layout_ptr, tag, stl_allocator)), + segments(*stl_allocator), scatter_functions(*stl_allocator), gather_functions(*stl_allocator) { Initialize(); } +TupleDataCollection::TupleDataCollection(ClientContext &context, shared_ptr layout_ptr, MemoryTag tag, + shared_ptr stl_allocator) + : TupleDataCollection(BufferManager::GetBufferManager(context), std::move(layout_ptr), tag, + std::move(stl_allocator)) { +} + TupleDataCollection::~TupleDataCollection() { } @@ -40,7 +50,7 @@ void TupleDataCollection::Initialize() { } unique_ptr TupleDataCollection::CreateUnique() const { - return make_uniq(allocator->GetBufferManager(), layout_ptr); + return make_uniq(allocator->GetBufferManager(), layout_ptr, tag); } void GetAllColumnIDsInternal(vector &column_ids, const idx_t column_count) { @@ -110,13 +120,13 @@ void TupleDataCollection::DestroyChunks(const idx_t chunk_idx_begin, const idx_t D_ASSERT(segments.size() == 1); // Assume 1 segment for now (multi-segment destroys can be implemented if needed) D_ASSERT(chunk_idx_begin <= chunk_idx_end && chunk_idx_end <= ChunkCount()); auto &segment = *segments[0]; - auto &chunk_begin = segment.chunks[chunk_idx_begin]; + auto &chunk_begin = *segment.chunks[chunk_idx_begin]; const auto row_block_begin = chunk_begin.row_block_ids.Start(); if (chunk_idx_end == ChunkCount()) { segment.allocator->DestroyRowBlocks(row_block_begin, segment.allocator->RowBlockCount()); } else { - auto &chunk_end = segment.chunks[chunk_idx_end]; + auto &chunk_end = *segment.chunks[chunk_idx_end]; const auto row_block_end = chunk_end.row_block_ids.Start(); segment.allocator->DestroyRowBlocks(row_block_begin, row_block_end); } @@ -129,7 +139,7 @@ void TupleDataCollection::DestroyChunks(const idx_t chunk_idx_begin, const idx_t if (chunk_idx_end == ChunkCount()) { segment.allocator->DestroyHeapBlocks(heap_block_begin, segment.allocator->HeapBlockCount()); } else { - auto &chunk_end = segment.chunks[chunk_idx_end]; + auto &chunk_end = *segment.chunks[chunk_idx_end]; if (chunk_end.heap_block_ids.Empty()) { return; } @@ -180,7 +190,7 @@ void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, v void TupleDataCollection::InitializeAppend(TupleDataPinState &pin_state, TupleDataPinProperties properties) { pin_state.properties = properties; if (segments.empty()) { - segments.emplace_back(make_unsafe_uniq(allocator)); + segments.emplace_back(stl_allocator->MakeUnsafePtr(allocator)); } } @@ -469,7 +479,7 @@ void TupleDataCollection::Combine(TupleDataCollection &other) { other.Reset(); } -void TupleDataCollection::AddSegment(unsafe_unique_ptr segment) { +void TupleDataCollection::AddSegment(unsafe_arena_ptr segment) { count += segment->count; data_size += segment->data_size; segments.emplace_back(std::move(segment)); @@ -504,7 +514,7 @@ void TupleDataCollection::InitializeChunk(DataChunk &chunk, const vectorGetAllocator(), chunk_types); } -void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { +void TupleDataCollection::InitializeScanChunk(const TupleDataScanState &state, DataChunk 
&chunk) const { auto &column_ids = state.chunk_state.column_ids; D_ASSERT(!column_ids.empty()); vector<LogicalType> chunk_types; @@ -562,11 +572,18 @@ void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vect InitializeScan(state.scan_state, std::move(column_ids), properties); } -idx_t TupleDataCollection::FetchChunk(TupleDataScanState &state, const idx_t segment_idx, const idx_t chunk_idx, - const bool init_heap) { - auto &segment = *segments[segment_idx]; - allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap); - return segment.chunks[chunk_idx].count; +idx_t TupleDataCollection::FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap, + optional_ptr<SortKeyPayloadState> sort_key_payload_state) { + for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { + auto &segment = *segments[segment_idx]; + if (chunk_idx < segment.ChunkCount()) { + segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap, + sort_key_payload_state); + return segment.chunks[chunk_idx]->count; + } + chunk_idx -= segment.ChunkCount(); + } + throw InternalException("Chunk index out of range in TupleDataCollection::FetchChunk"); } bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { @@ -609,6 +626,43 @@ bool TupleDataCollection::Scan(TupleDataParallelScanState &gstate, TupleDataLoca return true; } +idx_t TupleDataCollection::Seek(TupleDataScanState &state, const idx_t target_chunk) { + D_ASSERT(state.pin_state.properties == TupleDataPinProperties::UNPIN_AFTER_DONE); + state.pin_state.row_handles.clear(); + state.pin_state.heap_handles.clear(); + + // early return for empty collection + if (segments.empty()) { + return 0; + } + + idx_t current_chunk = 0; + idx_t total_rows = 0; + for (idx_t seg_idx = 0; seg_idx < segments.size(); seg_idx++) { + auto &segment = segments[seg_idx]; + idx_t chunk_count = segment->ChunkCount(); + + if (current_chunk + chunk_count <= target_chunk) { + total_rows += segment->count; + current_chunk += chunk_count; + } else { + idx_t chunk_idx_in_segment = target_chunk - current_chunk; + for (idx_t chunk_idx = 0; chunk_idx < chunk_idx_in_segment; chunk_idx++) { + total_rows += segment->chunks[chunk_idx]->count; + } + current_chunk += chunk_count; + + // reset scan state to target segment + state.segment_index = seg_idx; + state.chunk_index = chunk_idx_in_segment; + break; + } + } + + D_ASSERT(target_chunk < current_chunk); + return total_rows; +} + bool TupleDataCollection::ScanComplete(const TupleDataScanState &state) const { if (Count() == 0) { return true; } @@ -648,7 +702,7 @@ void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChu const vector<column_t> &column_ids, idx_t segment_index, idx_t chunk_index, DataChunk &result) { auto &segment = *segments[segment_index]; - auto &chunk = segment.chunks[chunk_index]; + const auto &chunk = *segment.chunks[chunk_index]; segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); result.Reset(); diff --git a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp index 03dd5db23..5bbe7841d 100644 --- a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp @@ -74,7 +74,7 @@ void TupleDataChunkIterator::Reset() { } idx_t TupleDataChunkIterator::GetCurrentChunkCount() const { - return
collection.segments[current_segment_idx]->chunks[current_chunk_idx].count; + return collection.segments[current_segment_idx]->chunks[current_chunk_idx]->count; } TupleDataChunkState &TupleDataChunkIterator::GetChunkState() { diff --git a/src/duckdb/src/common/types/row/tuple_data_layout.cpp b/src/duckdb/src/common/types/row/tuple_data_layout.cpp index 45e6420b2..75367a0ef 100644 --- a/src/duckdb/src/common/types/row/tuple_data_layout.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_layout.cpp @@ -131,7 +131,7 @@ void TupleDataLayout::Initialize(vector types_p, Aggregates aggrega for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { const auto &aggr = aggregates[aggr_idx]; - if (aggr.function.destructor) { + if (aggr.function.HasStateDestructorCallback()) { aggr_destructor_idxs.push_back(aggr_idx); } } diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp index fe671a46f..3c967d448 100644 --- a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp @@ -1794,7 +1794,6 @@ static void TupleDataCastToArrayStructGather(const TupleDataLayout &layout, Vect const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, const SelectionVector &target_sel, optional_ptr cached_cast_vector, const vector &child_functions) { - if (cached_cast_vector) { // Reuse the cached cast vector TupleDataStructGather(layout, row_locations, col_idx, scan_sel, scan_count, *cached_cast_vector, target_sel, diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp index 462e7e474..be6901670 100644 --- a/src/duckdb/src/common/types/row/tuple_data_segment.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_segment.cpp @@ -15,7 +15,7 @@ void TupleDataChunkPart::SetHeapEmpty() { base_heap_ptr = nullptr; } -TupleDataChunk::TupleDataChunk() : count(0), lock(make_unsafe_uniq()) { +TupleDataChunk::TupleDataChunk(mutex &lock_p) : count(0), lock(lock_p) { } static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noexcept { @@ -26,7 +26,7 @@ static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noex std::swap(a.lock, b.lock); } -TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept : count(0) { +TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept : count(0), lock(other.lock) { SwapTupleDataChunk(*this, other); } @@ -35,23 +35,24 @@ TupleDataChunk &TupleDataChunk::operator=(TupleDataChunk &&other) noexcept { return *this; } -TupleDataChunkPart &TupleDataChunk::AddPart(TupleDataSegment &segment, TupleDataChunkPart &&part) { +TupleDataChunkPart &TupleDataChunk::AddPart(TupleDataSegment &segment, unsafe_arena_ptr part_ptr) { + auto &part = *part_ptr; count += part.count; row_block_ids.Insert(part.row_block_index); if (!segment.layout.AllConstant() && part.total_heap_size > 0) { heap_block_ids.Insert(part.heap_block_index); } - part.lock = *lock; + part.lock = lock; part_ids.Insert(UnsafeNumericCast(segment.chunk_parts.size())); - segment.chunk_parts.emplace_back(std::move(part)); - return segment.chunk_parts.back(); + segment.chunk_parts.emplace_back(std::move(part_ptr)); + return part; } void TupleDataChunk::Verify(const TupleDataSegment &segment) const { #ifdef D_ASSERT_IS_ENABLED idx_t total_count = 0; for (auto part_id = part_ids.Start(); part_id < part_ids.End(); part_id++) { - total_count += 
segment.chunk_parts[part_id].count; + total_count += segment.chunk_parts[part_id]->count; } D_ASSERT(this->count == total_count); D_ASSERT(this->count <= STANDARD_VECTOR_SIZE); @@ -63,8 +64,8 @@ void TupleDataChunk::MergeLastChunkPart(TupleDataSegment &segment) { return; } - auto &second_to_last = segment.chunk_parts[part_ids.End() - 2]; - auto &last = segment.chunk_parts[part_ids.End() - 1]; + auto &second_to_last = *segment.chunk_parts[part_ids.End() - 2]; + auto &last = *segment.chunk_parts[part_ids.End() - 1]; auto rows_align = last.row_block_index == second_to_last.row_block_index && @@ -98,11 +99,8 @@ void TupleDataChunk::MergeLastChunkPart(TupleDataSegment &segment) { } TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) - : allocator(std::move(allocator_p)), layout(allocator->GetLayout()), count(0), data_size(0) { - // We initialize these with plenty of room so that we can avoid allocations - static constexpr idx_t CHUNK_RESERVATION = 64; - chunks.reserve(CHUNK_RESERVATION); - chunk_parts.reserve(CHUNK_RESERVATION); + : allocator(std::move(allocator_p)), layout(allocator->GetLayout()), count(0), data_size(0), + pinned_row_handles(allocator->GetStlAllocator()), pinned_heap_handles(allocator->GetStlAllocator()) { } TupleDataSegment::~TupleDataSegment() { @@ -112,7 +110,6 @@ TupleDataSegment::~TupleDataSegment() { } pinned_row_handles.clear(); pinned_heap_handles.clear(); - allocator.reset(); } idx_t TupleDataSegment::ChunkCount() const { @@ -131,18 +128,19 @@ void TupleDataSegment::Unpin() { void TupleDataSegment::Verify() const { #ifdef D_ASSERT_IS_ENABLED - const auto &layout = allocator->GetLayout(); + const auto &allocator_layout = allocator->GetLayout(); idx_t total_count = 0; idx_t total_size = 0; - for (const auto &chunk : chunks) { + for (const auto &chunk_ptr : chunks) { + const auto &chunk = *chunk_ptr; chunk.Verify(*this); total_count += chunk.count; - total_size += chunk.count * layout.GetRowWidth(); - if (!layout.AllConstant()) { + total_size += chunk.count * allocator_layout.GetRowWidth(); + if (!allocator_layout.AllConstant()) { for (auto part_id = chunk.part_ids.Start(); part_id < chunk.part_ids.End(); part_id++) { - total_size += chunk_parts[part_id].total_heap_size; + total_size += chunk_parts[part_id]->total_heap_size; } } } diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp index 145b6bfa1..a1232340c 100644 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ b/src/duckdb/src/common/types/selection_vector.cpp @@ -50,6 +50,14 @@ buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx return data; } +idx_t SelectionVector::SliceInPlace(const SelectionVector &source, idx_t count) { + for (idx_t i = 0; i < count; ++i) { + set_index(i, get_index(source.get_index(i))); + } + + return count; +} + void SelectionVector::Verify(idx_t count, idx_t vector_size) const { #ifdef DEBUG D_ASSERT(vector_size >= 1); diff --git a/src/duckdb/src/common/types/string_type.cpp b/src/duckdb/src/common/types/string_type.cpp index f5a236557..bea85327a 100644 --- a/src/duckdb/src/common/types/string_type.cpp +++ b/src/duckdb/src/common/types/string_type.cpp @@ -6,6 +6,8 @@ #include "utf8proc_wrapper.hpp" namespace duckdb { +constexpr idx_t string_t::MAX_STRING_SIZE; +constexpr idx_t string_t::INLINE_LENGTH; void string_t::Verify() const { #ifdef DEBUG diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index 2bef3a82d..f1352331f 100644 --- 
a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -919,6 +919,14 @@ Value Value::BIGNUM(const string &data) { return result; } +Value Value::GEOMETRY(const_data_ptr_t data, idx_t len) { + Value result; + result.type_ = LogicalType::GEOMETRY(); // construct type explicitly so that we get the ExtraTypeInfo + result.is_null = false; + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); + return result; +} + Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; diff --git a/src/duckdb/src/common/types/variant/variant.cpp b/src/duckdb/src/common/types/variant/variant.cpp new file mode 100644 index 000000000..ba7addcf7 --- /dev/null +++ b/src/duckdb/src/common/types/variant/variant.cpp @@ -0,0 +1,76 @@ +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/types/vector.hpp" + +namespace duckdb { + +VariantVectorData::VariantVectorData(Vector &variant) + : variant(variant), keys_index_validity(FlatVector::Validity(VariantVector::GetChildrenKeysIndex(variant))), + keys(VariantVector::GetKeys(variant)) { + blob_data = FlatVector::GetData(VariantVector::GetData(variant)); + type_ids_data = FlatVector::GetData(VariantVector::GetValuesTypeId(variant)); + byte_offset_data = FlatVector::GetData(VariantVector::GetValuesByteOffset(variant)); + keys_index_data = FlatVector::GetData(VariantVector::GetChildrenKeysIndex(variant)); + values_index_data = FlatVector::GetData(VariantVector::GetChildrenValuesIndex(variant)); + values_data = FlatVector::GetData(VariantVector::GetValues(variant)); + children_data = FlatVector::GetData(VariantVector::GetChildren(variant)); + keys_data = FlatVector::GetData(keys); +} + +UnifiedVariantVectorData::UnifiedVariantVectorData(const RecursiveUnifiedVectorFormat &variant) + : variant(variant), keys(UnifiedVariantVector::GetKeys(variant)), + keys_entry(UnifiedVariantVector::GetKeysEntry(variant)), children(UnifiedVariantVector::GetChildren(variant)), + keys_index(UnifiedVariantVector::GetChildrenKeysIndex(variant)), + values_index(UnifiedVariantVector::GetChildrenValuesIndex(variant)), + values(UnifiedVariantVector::GetValues(variant)), type_id(UnifiedVariantVector::GetValuesTypeId(variant)), + byte_offset(UnifiedVariantVector::GetValuesByteOffset(variant)), data(UnifiedVariantVector::GetData(variant)), + keys_index_validity(keys_index.validity) { + blob_data = data.GetData(); + type_id_data = type_id.GetData(); + byte_offset_data = byte_offset.GetData(); + keys_index_data = keys_index.GetData(); + values_index_data = values_index.GetData(); + values_data = values.GetData(); + children_data = children.GetData(); + keys_data = keys.GetData(); + keys_entry_data = keys_entry.GetData(); +} + +bool UnifiedVariantVectorData::RowIsValid(idx_t row) const { + return variant.unified.validity.RowIsValid(variant.unified.sel->get_index(row)); +} +bool UnifiedVariantVectorData::KeysIndexIsValid(idx_t row, idx_t index) const { + auto list_entry = GetChildrenListEntry(row); + return keys_index_validity.RowIsValid(keys_index.sel->get_index(list_entry.offset + index)); +} + +list_entry_t UnifiedVariantVectorData::GetChildrenListEntry(idx_t row) const { + return children_data[children.sel->get_index(row)]; +} +list_entry_t UnifiedVariantVectorData::GetValuesListEntry(idx_t row) const { + return values_data[values.sel->get_index(row)]; +} +const string_t &UnifiedVariantVectorData::GetKey(idx_t row, idx_t index) const { + auto list_entry = keys_data[keys.sel->get_index(row)]; + 
return keys_entry_data[keys_entry.sel->get_index(list_entry.offset + index)]; +} +uint32_t UnifiedVariantVectorData::GetKeysIndex(idx_t row, idx_t child_index) const { + auto list_entry = GetChildrenListEntry(row); + return keys_index_data[keys_index.sel->get_index(list_entry.offset + child_index)]; +} +uint32_t UnifiedVariantVectorData::GetValuesIndex(idx_t row, idx_t child_index) const { + auto list_entry = GetChildrenListEntry(row); + return values_index_data[values_index.sel->get_index(list_entry.offset + child_index)]; +} +VariantLogicalType UnifiedVariantVectorData::GetTypeId(idx_t row, idx_t value_index) const { + auto list_entry = values_data[values.sel->get_index(row)]; + return static_cast(type_id_data[type_id.sel->get_index(list_entry.offset + value_index)]); +} +uint32_t UnifiedVariantVectorData::GetByteOffset(idx_t row, idx_t value_index) const { + auto list_entry = values_data[values.sel->get_index(row)]; + return byte_offset_data[byte_offset.sel->get_index(list_entry.offset + value_index)]; +} +const string_t &UnifiedVariantVectorData::GetData(idx_t row) const { + return blob_data[data.sel->get_index(row)]; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/variant/variant_value.cpp b/src/duckdb/src/common/types/variant/variant_value.cpp new file mode 100644 index 000000000..fd5734e19 --- /dev/null +++ b/src/duckdb/src/common/types/variant/variant_value.cpp @@ -0,0 +1,784 @@ +#include "duckdb/common/types/variant_value.hpp" +#include "yyjson.hpp" + +#include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/datetime.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" + +using namespace duckdb_yyjson; // NOLINT + +namespace duckdb { + +void VariantValue::AddChild(const string &key, VariantValue &&val) { + D_ASSERT(value_type == VariantValueType::OBJECT); + object_children.emplace(key, std::move(val)); +} + +void VariantValue::AddItem(VariantValue &&val) { + D_ASSERT(value_type == VariantValueType::ARRAY); + array_items.push_back(std::move(val)); +} + +static void AnalyzeValue(const VariantValue &value, idx_t row, DataChunk &offsets) { + auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; + auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; + auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; + auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; + + values_offset++; + switch (value.value_type) { + case VariantValueType::OBJECT: { + //! Write the count of the children + auto &children = value.object_children; + data_offset += GetVarintSize(children.size()); + if (!children.empty()) { + //! Write the children offset + data_offset += GetVarintSize(children_offset); + children_offset += children.size(); + keys_offset += children.size(); + for (auto &child : children) { + auto &child_value = child.second; + AnalyzeValue(child_value, row, offsets); + } + } + break; + } + case VariantValueType::ARRAY: { + //! 
Write the count of the children + auto &children = value.array_items; + data_offset += GetVarintSize(children.size()); + if (!children.empty()) { + //! Write the children offset + data_offset += GetVarintSize(children_offset); + children_offset += children.size(); + for (auto &child : children) { + AnalyzeValue(child, row, offsets); + } + } + break; + } + case VariantValueType::PRIMITIVE: { + auto &primitive = value.primitive_value; + auto type_id = primitive.type().id(); + switch (type_id) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::SQLNULL: { + break; + } + case LogicalTypeId::TINYINT: { + data_offset += sizeof(int8_t); + break; + } + case LogicalTypeId::SMALLINT: { + data_offset += sizeof(int16_t); + break; + } + case LogicalTypeId::INTEGER: { + data_offset += sizeof(int32_t); + break; + } + case LogicalTypeId::BIGINT: { + data_offset += sizeof(int64_t); + break; + } + case LogicalTypeId::HUGEINT: { + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::UTINYINT: { + data_offset += sizeof(uint8_t); + break; + } + case LogicalTypeId::USMALLINT: { + data_offset += sizeof(uint16_t); + break; + } + case LogicalTypeId::UINTEGER: { + data_offset += sizeof(uint32_t); + break; + } + case LogicalTypeId::UBIGINT: { + data_offset += sizeof(uint64_t); + break; + } + case LogicalTypeId::UHUGEINT: { + data_offset += sizeof(uhugeint_t); + break; + } + case LogicalTypeId::DOUBLE: { + data_offset += sizeof(double); + break; + } + case LogicalTypeId::FLOAT: { + data_offset += sizeof(float); + break; + } + case LogicalTypeId::DATE: { + data_offset += sizeof(date_t); + break; + } + case LogicalTypeId::TIMESTAMP_TZ: { + data_offset += sizeof(timestamp_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP: { + data_offset += sizeof(timestamp_t); + break; + } + case LogicalTypeId::TIMESTAMP_SEC: { + data_offset += sizeof(timestamp_sec_t); + break; + } + case LogicalTypeId::TIMESTAMP_MS: { + data_offset += sizeof(timestamp_ms_t); + break; + } + case LogicalTypeId::TIME: { + data_offset += sizeof(dtime_t); + break; + } + case LogicalTypeId::TIME_NS: { + data_offset += sizeof(dtime_ns_t); + break; + } + case LogicalTypeId::TIME_TZ: { + data_offset += sizeof(dtime_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP_NS: { + data_offset += sizeof(timestamp_ns_t); + break; + } + case LogicalTypeId::INTERVAL: { + data_offset += sizeof(interval_t); + break; + } + case LogicalTypeId::UUID: { + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::DECIMAL: { + auto &type = primitive.type(); + uint8_t width; + uint8_t scale; + type.GetDecimalProperties(width, scale); + + auto physical_type = type.InternalType(); + data_offset += GetVarintSize(width); + data_offset += GetVarintSize(scale); + switch (physical_type) { + case PhysicalType::INT16: { + data_offset += sizeof(int16_t); + break; + } + case PhysicalType::INT32: { + data_offset += sizeof(int32_t); + break; + } + case PhysicalType::INT64: { + data_offset += sizeof(int64_t); + break; + } + case PhysicalType::INT128: { + data_offset += sizeof(hugeint_t); + break; + } + default: + throw InternalException("Unexpected physical type for Decimal value: %s", + EnumUtil::ToString(physical_type)); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::BIGNUM: + case LogicalTypeId::BIT: + case LogicalTypeId::GEOMETRY: + case LogicalTypeId::VARCHAR: { + auto string_data = primitive.GetValueUnsafe(); + data_offset += GetVarintSize(string_data.GetSize()); + data_offset += string_data.GetSize(); + break; + } + default: + throw 
InternalException("Encountered unrecognized LogicalType in VariantValue::AnalyzeValue: %s", + primitive.type().ToString()); + } + break; + } + default: + throw InternalException("VariantValueType not handled"); + } +} + +uint32_t GetOrCreateIndex(OrderedOwningStringMap &dictionary, const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; +} + +static void ConvertValue(const VariantValue &value, VariantVectorData &result, idx_t row, DataChunk &offsets, + SelectionVector &keys_selvec, OrderedOwningStringMap &dictionary) { + auto blob_data = data_ptr_cast(result.blob_data[row].GetDataWriteable()); + auto keys_list_offset = result.keys_data[row].offset; + auto children_list_offset = result.children_data[row].offset; + auto values_list_offset = result.values_data[row].offset; + + auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; + auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; + auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; + auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; + + switch (value.value_type) { + case VariantValueType::OBJECT: { + //! Write the count of the children + auto &children = value.object_children; + + //! values + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::OBJECT); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + values_offset++; + + //! data + VarintEncode(static_cast(children.size()), blob_data + data_offset); + data_offset += GetVarintSize(children.size()); + + if (!children.empty()) { + //! Write the children offset + VarintEncode(children_offset, blob_data + data_offset); + data_offset += GetVarintSize(children_offset); + + auto start_of_children = children_offset; + children_offset += children.size(); + + auto it = children.begin(); + for (idx_t i = 0; i < children.size(); i++) { + //! children + result.keys_index_data[children_list_offset + start_of_children + i] = keys_offset; + result.values_index_data[children_list_offset + start_of_children + i] = values_offset; + + auto &child = *it; + //! keys + auto &child_key = child.first; + auto dictionary_index = GetOrCreateIndex(dictionary, child_key); + keys_selvec.set_index(keys_list_offset + keys_offset, dictionary_index); + keys_offset++; + + auto &child_value = child.second; + ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); + it++; + } + } + break; + } + case VariantValueType::ARRAY: { + //! Write the count of the children + auto &children = value.array_items; + + //! values + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::ARRAY); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + values_offset++; + + //! data + VarintEncode(static_cast(children.size()), blob_data + data_offset); + data_offset += GetVarintSize(children.size()); + + if (!children.empty()) { + //! Write the children offset + VarintEncode(children_offset, blob_data + data_offset); + data_offset += GetVarintSize(children_offset); + + auto start_of_children = children_offset; + children_offset += children.size(); + + for (idx_t i = 0; i < children.size(); i++) { + //! 
children + result.keys_index_validity.SetInvalid(children_list_offset + start_of_children + i); + result.values_index_data[children_list_offset + start_of_children + i] = values_offset; + + auto &child_value = children[i]; + ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); + } + } + break; + } + case VariantValueType::PRIMITIVE: { + auto &primitive = value.primitive_value; + auto type_id = primitive.type().id(); + result.byte_offset_data[values_list_offset + values_offset] = data_offset; + switch (type_id) { + case LogicalTypeId::BOOLEAN: { + if (primitive.GetValue()) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BOOL_TRUE); + } else { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BOOL_FALSE); + } + break; + } + case LogicalTypeId::SQLNULL: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::VARIANT_NULL); + break; + } + case LogicalTypeId::TINYINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT8); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int8_t); + break; + } + case LogicalTypeId::SMALLINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT16); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int16_t); + break; + } + case LogicalTypeId::INTEGER: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT32); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int32_t); + break; + } + case LogicalTypeId::BIGINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT64); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int64_t); + break; + } + case LogicalTypeId::HUGEINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT128); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::UTINYINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT8); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(uint8_t); + break; + } + case LogicalTypeId::USMALLINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT16); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(uint16_t); + break; + } + case LogicalTypeId::UINTEGER: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT32); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(uint32_t); + break; + } + case LogicalTypeId::UBIGINT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT64); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(uint64_t); + break; + } + case LogicalTypeId::UHUGEINT: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::UINT128); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(uhugeint_t); + break; + } + case LogicalTypeId::DOUBLE: { + result.type_ids_data[values_list_offset + values_offset] = 
static_cast(VariantLogicalType::DOUBLE); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(double); + break; + } + case LogicalTypeId::FLOAT: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::FLOAT); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(float); + break; + } + case LogicalTypeId::DATE: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::DATE); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(date_t); + break; + } + case LogicalTypeId::TIMESTAMP_TZ: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_MICROS_TZ); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_MICROS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_t); + break; + } + case LogicalTypeId::TIMESTAMP_SEC: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_SEC); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_sec_t); + break; + } + case LogicalTypeId::TIMESTAMP_MS: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_MILIS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_ms_t); + break; + } + case LogicalTypeId::TIME: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIME_MICROS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(dtime_t); + break; + } + case LogicalTypeId::TIME_NS: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIME_NANOS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(dtime_ns_t); + break; + } + case LogicalTypeId::TIME_TZ: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIME_MICROS_TZ); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(dtime_tz_t); + break; + } + case LogicalTypeId::TIMESTAMP_NS: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::TIMESTAMP_NANOS); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(timestamp_ns_t); + break; + } + case LogicalTypeId::INTERVAL: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::INTERVAL); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(interval_t); + break; + } + case LogicalTypeId::UUID: { + result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UUID); + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(hugeint_t); + break; + } + case LogicalTypeId::DECIMAL: { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::DECIMAL); + auto &type = primitive.type(); + uint8_t width; + uint8_t scale; + type.GetDecimalProperties(width, scale); + + auto physical_type = type.InternalType(); + VarintEncode(width, blob_data + data_offset); + 
data_offset += GetVarintSize(width); + VarintEncode(scale, blob_data + data_offset); + data_offset += GetVarintSize(scale); + switch (physical_type) { + case PhysicalType::INT16: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int16_t); + break; + } + case PhysicalType::INT32: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int32_t); + break; + } + case PhysicalType::INT64: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(int64_t); + break; + } + case PhysicalType::INT128: { + Store(primitive.GetValueUnsafe(), blob_data + data_offset); + data_offset += sizeof(hugeint_t); + break; + } + default: + throw InternalException("Unexpected physical type for Decimal value: %s", + EnumUtil::ToString(physical_type)); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::BIGNUM: + case LogicalTypeId::BIT: + case LogicalTypeId::GEOMETRY: + case LogicalTypeId::VARCHAR: { + if (type_id == LogicalTypeId::BLOB) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BLOB); + } else if (type_id == LogicalTypeId::BIGNUM) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BIGNUM); + } else if (type_id == LogicalTypeId::BIT) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::BITSTRING); + } else if (type_id == LogicalTypeId::GEOMETRY) { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::GEOMETRY); + } else { + result.type_ids_data[values_list_offset + values_offset] = + static_cast(VariantLogicalType::VARCHAR); + } + auto string_data = primitive.GetValueUnsafe(); + auto string_size = string_data.GetSize(); + VarintEncode(static_cast(string_size), blob_data + data_offset); + data_offset += GetVarintSize(string_size); + memcpy(blob_data + data_offset, string_data.GetData(), string_size); + data_offset += string_size; + break; + } + default: + throw InternalException("Encountered unrecognized LogicalType in VariantValue::ConvertValue: %s", + primitive.type().ToString()); + } + values_offset++; + break; + } + default: + throw InternalException("VariantValueType not handled"); + } +} + +//! Copied and modified from 'to_variant.cpp' +static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVector &keys_selvec, idx_t &selvec_size) { + auto &keys = VariantVector::GetKeys(result); + auto keys_data = ListVector::GetData(keys); + + auto &children = VariantVector::GetChildren(result); + auto children_data = ListVector::GetData(children); + + auto &values = VariantVector::GetValues(result); + auto values_data = ListVector::GetData(values); + + auto &blob = VariantVector::GetData(result); + auto blob_data = FlatVector::GetData(blob); + + idx_t children_offset = 0; + idx_t values_offset = 0; + idx_t keys_offset = 0; + + auto keys_sizes = variant::OffsetData::GetKeys(offsets); + auto children_sizes = variant::OffsetData::GetChildren(offsets); + auto values_sizes = variant::OffsetData::GetValues(offsets); + auto blob_sizes = variant::OffsetData::GetBlob(offsets); + + auto count = offsets.size(); + for (idx_t i = 0; i < count; i++) { + auto &keys_entry = keys_data[i]; + auto &children_entry = children_data[i]; + auto &values_entry = values_data[i]; + + //! keys + keys_entry.length = keys_sizes[i]; + keys_entry.offset = keys_offset; + keys_offset += keys_entry.length; + + //! 
children + children_entry.length = children_sizes[i]; + children_entry.offset = children_offset; + children_offset += children_entry.length; + + //! values + values_entry.length = values_sizes[i]; + values_entry.offset = values_offset; + values_offset += values_entry.length; + + //! value + blob_data[i] = StringVector::EmptyString(blob, blob_sizes[i]); + } + + //! Reserve for the children of the lists + ListVector::Reserve(keys, keys_offset); + ListVector::Reserve(children, children_offset); + ListVector::Reserve(values, values_offset); + + //! Set list sizes + ListVector::SetListSize(keys, keys_offset); + ListVector::SetListSize(children, children_offset); + ListVector::SetListSize(values, values_offset); + + keys_selvec.Initialize(keys_offset); + selvec_size = keys_offset; +} + +void VariantValue::ToVARIANT(vector &input, Vector &result) { + auto count = input.size(); + if (input.empty()) { + return; + } + + //! Keep track of all the offsets for each row. + DataChunk analyze_offsets; + analyze_offsets.Initialize( + Allocator::DefaultAllocator(), + {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); + analyze_offsets.SetCardinality(count); + variant::InitializeOffsets(analyze_offsets, count); + + for (idx_t i = 0; i < count; i++) { + auto &value = input[i]; + if (value.IsNull()) { + continue; + } + AnalyzeValue(value, i, analyze_offsets); + } + + SelectionVector keys_selvec; + idx_t keys_selvec_size; + InitializeVariants(analyze_offsets, result, keys_selvec, keys_selvec_size); + + auto &keys = VariantVector::GetKeys(result); + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + + DataChunk conversion_offsets; + conversion_offsets.Initialize( + Allocator::DefaultAllocator(), + {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); + conversion_offsets.SetCardinality(count); + variant::InitializeOffsets(conversion_offsets, count); + + VariantVectorData variant_data(result); + for (idx_t i = 0; i < count; i++) { + auto &value = input[i]; + if (value.IsNull()) { + FlatVector::SetNull(result, i, true); + continue; + } + ConvertValue(value, variant_data, i, conversion_offsets, keys_selvec, dictionary); + } + +#ifdef DEBUG + { + auto conversion_keys_offset = variant::OffsetData::GetKeys(conversion_offsets); + auto conversion_children_offset = variant::OffsetData::GetChildren(conversion_offsets); + auto conversion_values_offset = variant::OffsetData::GetValues(conversion_offsets); + auto conversion_data_offset = variant::OffsetData::GetBlob(conversion_offsets); + + auto analyze_keys_offset = variant::OffsetData::GetKeys(analyze_offsets); + auto analyze_children_offset = variant::OffsetData::GetChildren(analyze_offsets); + auto analyze_values_offset = variant::OffsetData::GetValues(analyze_offsets); + auto analyze_data_offset = variant::OffsetData::GetBlob(analyze_offsets); + + for (idx_t i = 0; i < count; i++) { + D_ASSERT(conversion_keys_offset[i] == analyze_keys_offset[i]); + D_ASSERT(conversion_children_offset[i] == analyze_children_offset[i]); + D_ASSERT(conversion_values_offset[i] == analyze_values_offset[i]); + D_ASSERT(conversion_data_offset[i] == analyze_data_offset[i]); + } + } + +#endif + + //! 
Finalize the 'data' column of the VARIANT + auto conversion_data_offsets = variant::OffsetData::GetBlob(conversion_offsets); + for (idx_t i = 0; i < count; i++) { + auto &data = variant_data.blob_data[i]; + data.SetSizeAndFinalize(conversion_data_offsets[i], conversion_data_offsets[i]); + } + + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); + + keys_entry.Slice(keys_selvec, keys_selvec_size); + keys_entry.Flatten(keys_selvec_size); + + if (input.size() == 1) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +yyjson_mut_val *VariantValue::ToJSON(ClientContext &context, yyjson_mut_doc *doc) const { + switch (value_type) { + case VariantValueType::PRIMITIVE: { + if (primitive_value.IsNull()) { + return yyjson_mut_null(doc); + } + switch (primitive_value.type().id()) { + case LogicalTypeId::BOOLEAN: { + if (primitive_value.GetValue<bool>()) { + return yyjson_mut_true(doc); + } else { + return yyjson_mut_false(doc); + } + } + case LogicalTypeId::TINYINT: + return yyjson_mut_int(doc, primitive_value.GetValue<int8_t>()); + case LogicalTypeId::SMALLINT: + return yyjson_mut_int(doc, primitive_value.GetValue<int16_t>()); + case LogicalTypeId::INTEGER: + return yyjson_mut_int(doc, primitive_value.GetValue<int32_t>()); + case LogicalTypeId::BIGINT: + return yyjson_mut_int(doc, primitive_value.GetValue<int64_t>()); + case LogicalTypeId::FLOAT: + return yyjson_mut_real(doc, primitive_value.GetValue<float>()); + case LogicalTypeId::DOUBLE: + return yyjson_mut_real(doc, primitive_value.GetValue<double>()); + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::VARCHAR: { + auto value_str = primitive_value.ToString(); + return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); + } + case LogicalTypeId::TIMESTAMP: { + auto value_str = primitive_value.ToString(); + return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); + } + case LogicalTypeId::TIMESTAMP_TZ: { + auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue<string>(); + return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); + } + case LogicalTypeId::TIMESTAMP_NS: { + auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue<string>(); + return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); + } + default: + throw InternalException("Unexpected primitive type: %s", primitive_value.type().ToString()); + } + } + case VariantValueType::OBJECT: { + auto obj = yyjson_mut_obj(doc); + for (const auto &it : object_children) { + auto &key = it.first; + auto value = it.second.ToJSON(context, doc); + yyjson_mut_obj_add_val(doc, obj, key.c_str(), value); + } + return obj; + } + case VariantValueType::ARRAY: { + auto arr = yyjson_mut_arr(doc); + for (auto &item : array_items) { + auto value = item.ToJSON(context, doc); + yyjson_mut_arr_add_val(arr, value); + } + return arr; + } + default: + throw InternalException("Can't serialize this VariantValue type to JSON"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/variant/variant_value_convert.cpp b/src/duckdb/src/common/types/variant/variant_value_convert.cpp new file mode 100644 index 000000000..1d0ac274e --- /dev/null +++ b/src/duckdb/src/common/types/variant/variant_value_convert.cpp @@ -0,0 +1,55 @@ +#include "duckdb/function/variant/variant_value_convert.hpp" + +namespace duckdb { + +template <> +Value ValueConverter::VisitInteger(int8_t val) { + return Value::TINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int16_t val) { + return
Value::SMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(int32_t val) { + return Value::INTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(int64_t val) { + return Value::BIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(hugeint_t val) { + return Value::HUGEINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint8_t val) { + return Value::UTINYINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint16_t val) { + return Value::USMALLINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uint32_t val) { + return Value::UINTEGER(val); +} + +template <> +Value ValueConverter::VisitInteger(uint64_t val) { + return Value::UBIGINT(val); +} + +template <> +Value ValueConverter::VisitInteger(uhugeint_t val) { + return Value::UHUGEINT(val); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index ad27b162d..7363f952a 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -32,7 +32,8 @@ namespace duckdb { UnifiedVectorFormat::UnifiedVectorFormat() : sel(nullptr), data(nullptr), physical_type(PhysicalType::INVALID) { } -UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept : sel(nullptr), data(nullptr) { +UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept + : sel(nullptr), data(nullptr), physical_type(PhysicalType::INVALID) { bool refers_to_self = other.sel == &other.owned_sel; std::swap(sel, other.sel); std::swap(data, other.data); @@ -96,8 +97,7 @@ Vector::Vector(const Value &value) : type(value.type()) { Vector::Vector(Vector &&other) noexcept : vector_type(other.vector_type), type(std::move(other.type)), data(other.data), - validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)), - cached_hashes(std::move(other.cached_hashes)) { + validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)) { } void Vector::Reference(const Value &value) { @@ -171,7 +171,6 @@ void Vector::Reinterpret(const Vector &other) { auxiliary = make_shared_ptr(std::move(new_vector)); } else { AssignSharedPointer(auxiliary, other.auxiliary); - AssignSharedPointer(cached_hashes, other.cached_hashes); } data = other.data; validity = other.validity; @@ -235,6 +234,9 @@ void Vector::Slice(const Vector &other, const SelectionVector &sel, idx_t count) } void Vector::Slice(const SelectionVector &sel, idx_t count) { + if (!sel.IsSet() || count == 0) { + return; // Nothing to do here + } if (GetVectorType() == VectorType::CONSTANT_VECTOR) { // dictionary on a constant is just a constant return; @@ -276,7 +278,6 @@ void Vector::Slice(const SelectionVector &sel, idx_t count) { vector_type = VectorType::DICTIONARY_VECTOR; buffer = std::move(dict_buffer); auxiliary = std::move(child_ref); - cached_hashes.reset(); } void Vector::Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t count) { @@ -287,15 +288,25 @@ void Vector::Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t } void Vector::Dictionary(Vector &dict, idx_t dictionary_size, const SelectionVector &sel, idx_t count) { - if (DictionaryVector::CanCacheHashes(dict.GetType()) && !dict.cached_hashes) { - // Create an empty hash vector for this dictionary, potentially to be used for caching hashes later - // This needs to happen here, as we need to add "cached_hashes" to the original input Vector "dict" - 
dict.cached_hashes = make_buffer(Vector(LogicalType::HASH, false, false, 0)); - } Reference(dict); Dictionary(dictionary_size, sel, count); } +void Vector::Dictionary(buffer_ptr reusable_dict, const SelectionVector &sel) { + D_ASSERT(type.InternalType() != PhysicalType::STRUCT); + D_ASSERT(type == reusable_dict->data.GetType()); + vector_type = VectorType::DICTIONARY_VECTOR; + data = reusable_dict->data.data; + validity.Reset(); + + auto dict_buffer = make_buffer(sel); + dict_buffer->SetDictionarySize(reusable_dict->size.GetIndex()); + dict_buffer->SetDictionaryId(reusable_dict->id); + buffer = std::move(dict_buffer); + + auxiliary = std::move(reusable_dict); +} + void Vector::Slice(const SelectionVector &sel, idx_t count, SelCache &cache) { if (GetVectorType() == VectorType::DICTIONARY_VECTOR && GetType().InternalType() != PhysicalType::STRUCT) { // dictionary vector: need to merge dictionaries @@ -353,7 +364,6 @@ void Vector::Initialize(bool initialize_to_zero, idx_t capacity) { } void Vector::FindResizeInfos(vector &resize_infos, const idx_t multiplier) { - ResizeInfo resize_info(*this, data, buffer.get(), multiplier); resize_infos.emplace_back(resize_info); @@ -724,6 +734,10 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { auto str = reinterpret_cast(data)[index]; return Value::BIGNUM(const_data_ptr_cast(str.data.GetData()), str.data.GetSize()); } + case LogicalTypeId::GEOMETRY: { + auto str = reinterpret_cast(data)[index]; + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } case LogicalTypeId::AGGREGATE_STATE: { auto str = reinterpret_cast(data)[index]; return Value::AGGREGATE_STATE(vector->GetType(), const_data_ptr_cast(str.GetData()), str.GetSize()); @@ -802,7 +816,6 @@ Value Vector::GetValue(const Vector &v_p, idx_t index_p) { value.GetTypeMutable().CopyAuxInfo(v_p.GetType()); } if (v_p.GetType().id() != LogicalTypeId::AGGREGATE_STATE && value.type().id() != LogicalTypeId::AGGREGATE_STATE) { - D_ASSERT(v_p.GetType() == value.type()); } return value; @@ -908,7 +921,6 @@ idx_t Vector::GetAllocationSize(idx_t cardinality) const { } default: throw NotImplementedException("Vector::GetAllocationSize not implemented for type: %s", type.ToString()); - break; } } @@ -1219,7 +1231,6 @@ void Vector::ToUnifiedFormat(idx_t count, UnifiedVectorFormat &format) { } void Vector::RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data) { - input.ToUnifiedFormat(count, data.unified); data.logical_type = input.GetType(); @@ -1846,7 +1857,9 @@ void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { inverted_sel.set_index(offset++, current_index); inverted_sel.set_index(offset++, current_index); } - Vector inverted_vector(vector, inverted_sel, verify_count); + auto reusable_dict = DictionaryVector::CreateReusableDictionary(vector.type, verify_count); + auto &inverted_vector = reusable_dict->data; + inverted_vector.Slice(vector, inverted_sel, verify_count); inverted_vector.Flatten(verify_count); // now insert the NULL values at every other position for (idx_t i = 0; i < count; i++) { @@ -1860,8 +1873,13 @@ void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { original_sel.set_index(offset++, verify_count - 1 - i * 2); } // now slice the inverted vector with the inverted selection vector - vector.Dictionary(inverted_vector, verify_count, original_sel, count); - DictionaryVector::SetDictionaryId(vector, UUID::ToString(UUID::GenerateRandomUUID())); + if (vector.GetType().InternalType() == 
PhysicalType::STRUCT) { + // Reusable dictionary API does not work for STRUCT + vector.Dictionary(inverted_vector, verify_count, original_sel, count); + vector.buffer->Cast().SetDictionaryId(reusable_dict->id); + } else { + vector.Dictionary(reusable_dict, original_sel); + } vector.Verify(count); } @@ -1922,17 +1940,27 @@ void Vector::DebugShuffleNestedVector(Vector &vector, idx_t count) { //===--------------------------------------------------------------------===// // DictionaryVector //===--------------------------------------------------------------------===// +buffer_ptr DictionaryVector::CreateReusableDictionary(const LogicalType &type, const idx_t &size) { + auto res = make_buffer(Vector(type, size)); + res->size = size; + res->id = UUID::ToString(UUID::GenerateRandomUUID()); + return res; +} + const Vector &DictionaryVector::GetCachedHashes(Vector &input) { D_ASSERT(CanCacheHashes(input)); - auto &dictionary = Child(input); - auto &dictionary_hashes = dictionary.cached_hashes->Cast().data; - if (!dictionary_hashes.data) { + + auto &child = input.auxiliary->Cast(); + lock_guard guard(child.cached_hashes_lock); + + if (!child.cached_hashes.data) { // Uninitialized: hash the dictionary - const auto dictionary_count = DictionarySize(input).GetIndex(); - dictionary_hashes.Initialize(false, dictionary_count); - VectorOperations::Hash(dictionary, dictionary_hashes, dictionary_count); + const auto dictionary_size = DictionarySize(input).GetIndex(); + D_ASSERT(!child.size.IsValid() || child.size.GetIndex() == dictionary_size); + child.cached_hashes.Initialize(false, dictionary_size); + VectorOperations::Hash(child.data, child.cached_hashes, dictionary_size); } - return dictionary.cached_hashes->Cast().data; + return child.cached_hashes; } //===--------------------------------------------------------------------===// @@ -2317,7 +2345,6 @@ const Vector &MapVector::GetValues(const Vector &vector) { } MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); // unify the MAP vector, which is a physical LIST vector @@ -2332,7 +2359,6 @@ MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const Sel keys.ToUnifiedFormat(maps_length, key_data); for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto mapped_row = sel.get_index(row_idx); auto map_idx = map_data.sel->get_index(mapped_row); @@ -2503,7 +2529,6 @@ void ListVector::PushBack(Vector &target, const Value &insert) { } idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { - auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); if (info.needs_slicing) { SelectionVector sel(info.child_list_info.length); @@ -2516,7 +2541,6 @@ idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t of } ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { - ConsecutiveChildListInfo info; UnifiedVectorFormat unified_list_data; list.ToUnifiedFormat(offset + count, unified_list_data); diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp index ac4c88c01..775757c59 100644 --- a/src/duckdb/src/common/value_operations/comparison_operations.cpp +++ b/src/duckdb/src/common/value_operations/comparison_operations.cpp @@ -141,6 +141,11 @@ static bool TemplatedBooleanOperation(const Value &left, const Value &right) { auto 
&right_children = StructValue::GetChildren(right); // this should be enforced by the type D_ASSERT(left_children.size() == right_children.size()); + if (left_children.empty()) { + const auto const_true = Value::BOOLEAN(true); + return ValuePositionComparator::Final(const_true, const_true); + } + idx_t i = 0; for (; i < left_children.size() - 1; ++i) { if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp index e57f9738d..d08a2d276 100644 --- a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp +++ b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp @@ -289,6 +289,7 @@ template idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel, optional_ptr null_mask) { if (!sel) { + D_ASSERT(count <= STANDARD_VECTOR_SIZE); sel = FlatVector::IncrementalSelectionVector(); } @@ -468,7 +469,6 @@ using StructEntries = vector>; void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, OptionalSelection &opt) { - for (idx_t i = 0; i < count;) { const auto slice_idx = slice_sel.get_index(i); const auto result_idx = sel.get_index(slice_idx); @@ -478,21 +478,21 @@ void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, } void ExtractNestedMask(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - ValidityMask *child_mask, optional_ptr null_mask) { - - if (!child_mask) { + ValidityMask *child_mask_p, optional_ptr null_mask) { + if (!child_mask_p) { return; } + auto &child_mask = *child_mask_p; for (idx_t i = 0; i < count; ++i) { const auto slice_idx = slice_sel.get_index(i); const auto result_idx = sel.get_index(slice_idx); - if (child_mask && !child_mask->RowIsValid(slice_idx)) { + if (!child_mask.RowIsValid(slice_idx)) { null_mask->SetInvalid(result_idx); } } - child_mask->Reset(null_mask->Capacity()); + child_mask.Reset(null_mask->Capacity()); } void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { @@ -767,8 +767,6 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select return count; } - // FIXME: This function can probably be optimized since we know the array size is fixed for every entry. - D_ASSERT(ArrayType::GetSize(left.GetType()) == ArrayType::GetSize(right.GetType())); auto array_size = ArrayType::GetSize(left.GetType()); @@ -808,39 +806,13 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select } idx_t match_count = 0; - for (idx_t pos = 0; count > 0; ++pos) { + for (idx_t pos = 0; pos < array_size && count > 0; ++pos) { // Set up the cursors for the current position PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - // Tie-break the pairs where one of the LISTs is exhausted. 
idx_t true_count = 0; idx_t false_count = 0; - idx_t maybe_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - if (array_size == pos) { - const auto idx = sel.get_index(slice_idx); - if (PositionComparator::TieBreak(array_size, array_size)) { - true_opt.Append(true_count, idx); - } else { - false_opt.Append(false_count, idx); - } - } else { - true_sel.set_index(maybe_count++, slice_idx); - } - } - true_opt.Advance(true_count); - false_opt.Advance(false_count); - match_count += true_count; - - // Redensify the list cursors - if (maybe_count < count) { - count = maybe_count; - DensifyNestedSelection(true_sel, count, slice_sel); - PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); - PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - } // Find everything that definitely matches true_count = @@ -878,6 +850,15 @@ idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const Select count = true_count; } + if (count > 0) { + if (PositionComparator::TieBreak(array_size, array_size)) { + ExtractNestedSelection(slice_sel, count, sel, true_opt); + match_count += count; + } else { + ExtractNestedSelection(slice_sel, count, sel, false_opt); + } + } + return match_count; } @@ -890,6 +871,7 @@ idx_t DistinctSelectNested(Vector &left, Vector &right, optional_ptr(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, false_opt, null_mask); - switch (left.GetType().InternalType()) { + auto &left_type = left.GetType(); + switch (left_type.InternalType()) { case PhysicalType::LIST: match_count += DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); @@ -1009,7 +992,6 @@ template idx_t TemplatedDistinctSelectOperation(Vector &left, Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - switch (left.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: diff --git a/src/duckdb/src/common/vector_operations/vector_copy.cpp b/src/duckdb/src/common/vector_operations/vector_copy.cpp index af75d56b9..2b333bc99 100644 --- a/src/duckdb/src/common/vector_operations/vector_copy.cpp +++ b/src/duckdb/src/common/vector_operations/vector_copy.cpp @@ -39,7 +39,6 @@ static const ValidityMask &ExtractValidityMask(const Vector &v) { void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, idx_t source_offset, idx_t target_offset, idx_t copy_count) { - SelectionVector owned_sel; const SelectionVector *sel = &sel_p; diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp index 7940d0120..559fea635 100644 --- a/src/duckdb/src/common/virtual_file_system.cpp +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -34,8 +34,9 @@ unique_ptr VirtualFileSystem::OpenFileExtended(const OpenFileInfo &f } } // open the base file handle in UNCOMPRESSED mode + flags.SetCompression(FileCompressionType::UNCOMPRESSED); - auto file_handle = FindFileSystem(file.path).OpenFile(file, flags, opener); + auto file_handle = FindFileSystem(file.path, opener).OpenFile(file, flags, opener); if (!file_handle) { return nullptr; } @@ -87,6 +88,9 @@ string VirtualFileSystem::GetVersionTag(FileHandle &handle) { FileType VirtualFileSystem::GetFileType(FileHandle &handle) { return handle.file_system.GetFileType(handle); } +FileMetadata VirtualFileSystem::Stats(FileHandle &handle) { + return 
handle.file_system.Stats(handle); +} void VirtualFileSystem::Truncate(FileHandle &handle, int64_t new_size) { handle.file_system.Truncate(handle, new_size); @@ -111,7 +115,7 @@ void VirtualFileSystem::RemoveDirectory(const string &directory, optional_ptr &callback, optional_ptr opener) { - return FindFileSystem(directory).ListFiles(directory, callback, opener); + return FindFileSystem(directory, opener).ListFiles(directory, callback, opener); } void VirtualFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -119,7 +123,7 @@ void VirtualFileSystem::MoveFile(const string &source, const string &target, opt } bool VirtualFileSystem::FileExists(const string &filename, optional_ptr opener) { - return FindFileSystem(filename).FileExists(filename, opener); + return FindFileSystem(filename, opener).FileExists(filename, opener); } bool VirtualFileSystem::IsPipe(const string &filename, optional_ptr opener) { @@ -139,7 +143,7 @@ string VirtualFileSystem::PathSeparator(const string &path) { } vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { - return FindFileSystem(path).Glob(path, opener); + return FindFileSystem(path, opener).Glob(path, opener); } void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { @@ -224,16 +228,72 @@ bool VirtualFileSystem::SubSystemIsDisabled(const string &name) { return disabled_file_systems.find(name) != disabled_file_systems.end(); } +bool VirtualFileSystem::IsDisabledForPath(const string &path) { + if (disabled_file_systems.empty()) { + return false; + } + auto fs = FindFileSystemInternal(path); + if (!fs) { + fs = default_fs.get(); + } + return disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end(); +} + +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr opener) { + return FindFileSystem(path, FileOpener::TryGetDatabase(opener)); +} + +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr db_instance) { + auto fs = FindFileSystemInternal(path); + + if (!fs && db_instance) { + string required_extension; + + for (const auto &entry : EXTENSION_FILE_PREFIXES) { + if (StringUtil::StartsWith(path, entry.name)) { + required_extension = entry.extension; + } + } + if (!required_extension.empty() && db_instance && !db_instance->ExtensionIsLoaded(required_extension)) { + auto &dbconfig = DBConfig::GetConfig(*db_instance); + if (!ExtensionHelper::CanAutoloadExtension(required_extension) || + !dbconfig.options.autoload_known_extensions) { + auto error_message = "File " + path + " requires the extension " + required_extension + " to be loaded"; + error_message = + ExtensionHelper::AddExtensionInstallHintToErrorMsg(*db_instance, error_message, required_extension); + throw MissingExtensionException(error_message); + } + // an extension is required to read this file, but it is not loaded - try to load it + ExtensionHelper::AutoLoadExtension(*db_instance, required_extension); + } + + // Retry after having autoloaded + fs = FindFileSystem(path); + } + + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); + } + return *fs; +} + FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { - auto &fs = FindFileSystemInternal(path); - if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { - throw 
PermissionException("File system %s has been disabled by configuration", fs.GetName()); + auto fs = FindFileSystemInternal(path); + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); } - return fs; + return *fs; } -FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { +optional_ptr VirtualFileSystem::FindFileSystemInternal(const string &path) { FileSystem *fs = nullptr; + for (auto &sub_system : sub_systems) { if (sub_system->CanHandleFile(path)) { if (sub_system->IsManuallySet()) { @@ -245,7 +305,9 @@ FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { if (fs) { return *fs; } - return *default_fs; + + // We could use default_fs, that's on the caller + return nullptr; } } // namespace duckdb diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp index 87c3bc661..2b51f5cf0 100644 --- a/src/duckdb/src/execution/aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/aggregate_hashtable.cpp @@ -48,7 +48,6 @@ GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context_p, A : BaseAggregateHashTable(context_p, allocator, aggregate_objects_p, std::move(payload_types_p)), context(context_p), radix_bits(radix_bits), count(0), capacity(0), sink_count(0), skip_lookups(false), enable_hll(false), aggregate_allocator(make_shared_ptr(allocator)), state(*aggregate_allocator) { - // Append hash column to the end and initialise the row layout group_types_p.emplace_back(LogicalType::HASH); @@ -76,8 +75,8 @@ void GroupedAggregateHashTable::InitializePartitionedData() { if (!partitioned_data || RadixPartitioning::RadixBitsOfPowerOfTwo(partitioned_data->PartitionCount()) != radix_bits) { D_ASSERT(!partitioned_data || partitioned_data->Count() == 0); - partitioned_data = - make_uniq(buffer_manager, layout_ptr, radix_bits, layout_ptr->ColumnCount() - 1); + partitioned_data = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, + radix_bits, layout_ptr->ColumnCount() - 1); } else { partitioned_data->Reset(); } @@ -93,8 +92,8 @@ void GroupedAggregateHashTable::InitializePartitionedData() { void GroupedAggregateHashTable::InitializeUnpartitionedData() { D_ASSERT(radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD); if (!unpartitioned_data) { - unpartitioned_data = - make_uniq(buffer_manager, layout_ptr, 0ULL, layout_ptr->ColumnCount() - 1); + unpartitioned_data = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, + 0ULL, layout_ptr->ColumnCount() - 1); } else { unpartitioned_data->Reset(); } diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index ec11c1289..0278df506 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -181,6 +181,8 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co } else { VectorOperations::DefaultCast(vector, intermediate, count, true); } + intermediate.Verify(count); + //! FIXME: this is probably also where we want to test 'variant_normalize' Vector result(vector.GetType(), true, false, count); //! 
Then cast back into the original type @@ -190,6 +192,7 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co VectorOperations::DefaultCast(intermediate, result, count, true); } vector.Reference(result); + vector.Verify(count); } } @@ -227,7 +230,6 @@ void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, // The result vector must be used for the first time, or must be reset. // Otherwise, the validity mask can contain previous (now incorrect) data. if (result.GetVectorType() == VectorType::FLAT_VECTOR) { - // We do not initialize vector caches for these expressions. if (expr.GetExpressionClass() != ExpressionClass::BOUND_REF && expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT && diff --git a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp index 6e78de49c..f55c26df3 100644 --- a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp @@ -138,8 +138,19 @@ static idx_t TemplatedSelectOperation(Vector &left, Vector &right, optional_ptr< false_sel.get()); case PhysicalType::LIST: case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - return NestedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); + case PhysicalType::ARRAY: { + auto result_count = NestedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); + if (true_sel && result_count > 0) { + std::sort(true_sel->data(), true_sel->data() + result_count); + } + if (false_sel) { + idx_t false_count = count - result_count; + if (false_count > 0) { + std::sort(false_sel->data(), false_sel->data() + false_count); + } + } + return result_count; + } default: throw InternalException("Invalid type for comparison"); } @@ -209,7 +220,6 @@ idx_t NestedSelector::Select(Vector &left, Vector &ri static inline idx_t SelectNotNull(Vector &left, Vector &right, const idx_t count, const SelectionVector &sel, SelectionVector &maybe_vec, OptionalSelection &false_opt, optional_ptr null_mask) { - UnifiedVectorFormat lvdata, rvdata; left.ToUnifiedFormat(count, lvdata); right.ToUnifiedFormat(count, rvdata); diff --git a/src/duckdb/src/execution/expression_executor/execute_function.cpp b/src/duckdb/src/execution/expression_executor/execute_function.cpp index a7e99287b..d3610dae2 100644 --- a/src/duckdb/src/execution/expression_executor/execute_function.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_function.cpp @@ -71,7 +71,7 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp return false; // Dictionary is too large, bail } - if (input_dictionary_id != current_input_dictionary_id) { + if (!output_dictionary || current_input_dictionary_id != input_dictionary_id) { // We haven't seen this dictionary before const auto chunk_fill_ratio = static_cast(args.size()) / STANDARD_VECTOR_SIZE; if (input_dictionary_size > STANDARD_VECTOR_SIZE && chunk_fill_ratio <= CHUNK_FILL_RATIO_THRESHOLD) { @@ -82,9 +82,8 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp } // We can do dictionary optimization! 
Re-initialize + output_dictionary = DictionaryVector::CreateReusableDictionary(result.GetType(), input_dictionary_size); current_input_dictionary_id = input_dictionary_id; - output_dictionary = make_uniq(result.GetType(), input_dictionary_size); - output_dictionary_id = UUID::ToString(UUID::GenerateRandomUUID()); // Set up the input chunk DataChunk input_chunk; @@ -105,16 +104,14 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp input_chunk.SetCardinality(count); // Execute, storing the result in an intermediate vector, and copying it to the output dictionary - Vector output_intermediate(output_dictionary->GetType()); - expr.function.function(input_chunk, state, output_intermediate); - VectorOperations::Copy(output_intermediate, *output_dictionary, count, 0, offset); + Vector output_intermediate(result.GetType()); + expr.function.GetFunctionCallback()(input_chunk, state, output_intermediate); + VectorOperations::Copy(output_intermediate, output_dictionary->data, count, 0, offset); } } - // Create a dictionary result vector and give it an ID - const auto &input_sel_vector = DictionaryVector::SelVector(unary_input); - result.Dictionary(*output_dictionary, input_dictionary_size, input_sel_vector, args.size()); - DictionaryVector::SetDictionaryId(result, output_dictionary_id); + // Result references the dictionary + result.Dictionary(output_dictionary, DictionaryVector::SelVector(unary_input)); return true; } @@ -127,15 +124,15 @@ unique_ptr ExpressionExecutor::InitializeState(const BoundFunct } result->Finalize(); - if (expr.function.init_local_state) { - result->local_state = expr.function.init_local_state(*result, expr, expr.bind_info.get()); + if (expr.function.HasInitStateCallback()) { + result->local_state = expr.function.GetInitStateCallback()(*result, expr, expr.bind_info.get()); } return std::move(result); } static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { #ifdef DEBUG - if (args.data.empty() || expr.function.null_handling != FunctionNullHandling::DEFAULT_NULL_HANDLING) { + if (args.data.empty() || expr.function.GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING) { return; } @@ -184,10 +181,10 @@ void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, Expression arguments.SetCardinality(count); arguments.Verify(); - D_ASSERT(expr.function.function); + D_ASSERT(expr.function.HasFunctionCallback()); auto &execute_function_state = state->Cast(); if (!execute_function_state.TryExecuteDictionaryExpression(expr, arguments, *state, result)) { - expr.function.function(arguments, *state, result); + expr.function.GetFunctionCallback()(arguments, *state, result); } VerifyNullHandling(expr, arguments, result); diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index 87c9cbf9b..300ad7bbc 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -50,7 +50,6 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co const IndexStorageInfo &info) : BoundIndex(name, ART::TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db), allocators(allocators_ptr), owns_data(false), verify_max_key_len(false) { - // FIXME: Use the new byte representation function to support nested types. 
for (idx_t i = 0; i < types.size(); i++) { switch (types[i]) { @@ -522,7 +521,9 @@ ErrorData ART::Insert(IndexLock &l, DataChunk &chunk, Vector &row_ids, IndexAppe if (keys[i].Empty()) { continue; } - D_ASSERT(ARTOperator::Lookup(*this, tree, keys[i], 0)); + auto leaf = ARTOperator::Lookup(*this, tree, keys[i], 0); + D_ASSERT(leaf); + D_ASSERT(ARTOperator::LookupInLeaf(*this, *leaf, row_id_keys[i])); } #endif return ErrorData(); @@ -602,8 +603,9 @@ void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { continue; } auto leaf = ARTOperator::Lookup(*this, tree, keys[i], 0); - if (leaf && leaf->GetType() == NType::LEAF_INLINED) { - D_ASSERT(leaf->GetRowId() != row_id_keys[i].GetRowId()); + if (leaf) { + auto contains_row_id = ARTOperator::LookupInLeaf(*this, *leaf, row_id_keys[i]); + D_ASSERT(!contains_row_id); } } #endif @@ -634,7 +636,7 @@ bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &ro Iterator it(*this); // Early-out, if the maximum value in the ART is lower than the lower bound. - if (!it.LowerBound(tree, key, equal, 0)) { + if (!it.LowerBound(tree, key, equal)) { return true; } @@ -667,7 +669,7 @@ bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_e Iterator it(*this); // Early-out, if the maximum value in the ART is lower than the lower bound. - if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { + if (!it.LowerBound(tree, lower_bound, left_equal)) { return true; } @@ -942,6 +944,8 @@ IndexStorageInfo ART::PrepareSerialize(const case_insensitive_map_t &opti } IndexStorageInfo ART::SerializeToDisk(QueryContext context, const case_insensitive_map_t &options) { + lock_guard guard(lock); + // If the storage format uses deprecated leaf storage, // then we need to transform all nested leaves before serialization. 
auto v1_0_0_option = options.find("v1_0_0_storage"); @@ -1047,7 +1051,16 @@ idx_t ART::GetInMemorySize(IndexLock &index_lock) { return in_memory_size; } -//===--------------------------------------------------------------------===// +bool ART::RequiresTransactionality() const { + return true; +} + +unique_ptr ART::CreateEmptyCopy(const string &name_prefix, IndexConstraintType constraint_type) const { + return make_uniq(name_prefix + name, constraint_type, GetColumnIds(), table_io_manager, unbound_expressions, + db); +} + +//===-------------------------------------------------------------------===// // Vacuum //===--------------------------------------------------------------------===// @@ -1205,17 +1218,27 @@ bool ART::MergeIndexes(IndexLock &state, BoundIndex &other_index) { // Verification //===--------------------------------------------------------------------===// -string ART::VerifyAndToString(IndexLock &l, const bool only_verify) { - return VerifyAndToStringInternal(only_verify); +string ART::ToString(IndexLock &l, bool display_ascii) { + return ToStringInternal(display_ascii); } -string ART::VerifyAndToStringInternal(const bool only_verify) { +string ART::ToStringInternal(bool display_ascii) { if (tree.HasMetadata()) { - return "ART: " + tree.VerifyAndToString(*this, only_verify); + return "\nART: \n" + tree.ToString(*this, ToStringOptions(0, false, display_ascii, nullptr, 0, 0, true, false)); } return "[empty]"; } +void ART::Verify(IndexLock &l) { + VerifyInternal(); +} + +void ART::VerifyInternal() { + if (tree.HasMetadata()) { + tree.Verify(*this); + } +} + void ART::VerifyAllocations(IndexLock &l) { return VerifyAllocationsInternal(); } diff --git a/src/duckdb/src/execution/index/art/art_merger.cpp b/src/duckdb/src/execution/index/art/art_merger.cpp index 70781cbfb..61d2ec317 100644 --- a/src/duckdb/src/execution/index/art/art_merger.cpp +++ b/src/duckdb/src/execution/index/art/art_merger.cpp @@ -217,9 +217,6 @@ void ARTMerger::MergeNodeAndPrefix(Node &node, Node &prefix, const GateStatus pa auto child = node.GetChildMutable(art, byte); // Reduce the prefix to the bytes after pos. - // We always reduce by at least one byte, - // thus, if the prefix was a gate, it no longer is. - prefix.SetGateStatus(GateStatus::GATE_NOT_SET); Prefix::Reduce(art, prefix, pos); if (child) { diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp index a694ca3b5..4a9332fc9 100644 --- a/src/duckdb/src/execution/index/art/base_leaf.cpp +++ b/src/duckdb/src/execution/index/art/base_leaf.cpp @@ -30,8 +30,10 @@ void BaseLeaf::InsertByteInternal(BaseLeaf &n, const uint8_t byt } template -BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, Node &node, const uint8_t byte) { - auto &n = Node::Ref(art, node, node.GetType()); +NodeHandle> BaseLeaf::DeleteByteInternal(ART &art, Node &node, + const uint8_t byte) { + NodeHandle> handle(art, node); + auto &n = handle.Get(); uint8_t child_pos = 0; for (; child_pos < n.count; child_pos++) { @@ -45,7 +47,7 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, for (uint8_t i = child_pos; i < n.count; i++) { n.key[i] = n.key[i + 1]; } - return n; + return handle; } //===--------------------------------------------------------------------===// @@ -53,27 +55,36 @@ BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, //===--------------------------------------------------------------------===// void Node7Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node15. 
- auto &n7 = Node::Ref(art, node, NODE_7_LEAF); - if (n7.count == CAPACITY) { - auto node7 = node; - Node15Leaf::GrowNode7Leaf(art, node, node7); - Node15Leaf::InsertByte(art, node, byte); - return; - } + { + NodeHandle handle(art, node); + auto &n7 = handle.Get(); - InsertByteInternal(n7, byte); + if (n7.count != CAPACITY) { + InsertByteInternal(n7, byte); + return; + } + } + // The node is full. Grow to Node15. + auto node7 = node; + Node15Leaf::GrowNode7Leaf(art, node, node7); + Node15Leaf::InsertByte(art, node, byte); } void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byte, const ARTKey &row_id) { - auto &n7 = DeleteByteInternal(art, node, byte); + idx_t remainder; + { + auto n7_handle = DeleteByteInternal(art, node, byte); + auto &n7 = n7_handle.Get(); + + if (n7.count != 1) { + return; + } - // Compress one-way nodes. - if (n7.count == 1) { + // Compress one-way nodes. D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); // Get the remaining row ID. - auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; + remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; remainder |= UnsafeNumericCast(n7.key[0]); // Free the prefix (nodes) and inline the remainder. @@ -82,23 +93,27 @@ void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byt Leaf::New(prefix, UnsafeNumericCast(remainder)); return; } - - // Free the Node7Leaf and inline the remainder. - Node::FreeNode(art, node); - Leaf::New(node, UnsafeNumericCast(remainder)); } + // Free the Node7Leaf and inline the remainder. + Node::FreeNode(art, node); + Leaf::New(node, UnsafeNumericCast(remainder)); } void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) { - auto &n7 = New(art, node7_leaf); - auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); - node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + { + auto n7_handle = New(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n7.count = n15.count; - for (uint8_t i = 0; i < n15.count; i++) { - n7.key[i] = n15.key[i]; - } + NodeHandle n15_handle(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); + + n7.count = n15.count; + for (uint8_t i = 0; i < n15.count; i++) { + n7.key[i] = n15.key[i]; + } + } Node::FreeNode(art, node15_leaf); } @@ -107,54 +122,66 @@ void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) //===--------------------------------------------------------------------===// void Node15Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node256Leaf. - auto &n15 = Node::Ref(art, node, NODE_15_LEAF); - if (n15.count == CAPACITY) { - auto node15 = node; - Node256Leaf::GrowNode15Leaf(art, node, node15); - Node256Leaf::InsertByte(art, node, byte); - return; + { + NodeHandle n15_handle(art, node); + auto &n15 = n15_handle.Get(); + if (n15.count != CAPACITY) { + InsertByteInternal(n15, byte); + return; + } } - - InsertByteInternal(n15, byte); + auto node15 = node; + Node256Leaf::GrowNode15Leaf(art, node, node15); + Node256Leaf::InsertByte(art, node, byte); } void Node15Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { - auto &n15 = DeleteByteInternal(art, node, byte); - - // Shrink node to Node7. 
- if (n15.count < Node7Leaf::CAPACITY) { - auto node15 = node; - Node7Leaf::ShrinkNode15Leaf(art, node, node15); + { + auto n15_handle = DeleteByteInternal(art, node, byte); + auto &n15 = n15_handle.Get(); + if (n15.count >= Node7Leaf::CAPACITY) { + return; + } } + auto node15 = node; + Node7Leaf::ShrinkNode15Leaf(art, node, node15); } void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { - auto &n7 = Node::Ref(art, node7_leaf, NType::NODE_7_LEAF); - auto &n15 = New(art, node15_leaf); - node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + { + NodeHandle n7_handle(art, node7_leaf); + auto &n7 = n7_handle.Get(); - n15.count = n7.count; - for (uint8_t i = 0; i < n7.count; i++) { - n15.key[i] = n7.key[i]; - } + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); + n15.count = n7.count; + for (uint8_t i = 0; i < n7.count; i++) { + n15.key[i] = n7.key[i]; + } + } Node::FreeNode(art, node7_leaf); } void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { - auto &n15 = New(art, node15_leaf); - auto &n256 = Node::Ref(art, node256_leaf, NType::NODE_256_LEAF); - node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); - - ValidityMask mask(&n256.mask[0], Node256::CAPACITY); - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (mask.RowIsValid(i)) { - n15.key[n15.count] = UnsafeNumericCast(i); - n15.count++; + { + auto n15_handle = New(art, node15_leaf); + auto &n15 = n15_handle.Get(); + + NodeHandle n256_handle(art, node256_leaf); + auto &n256 = n256_handle.Get(); + + node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); + + ValidityMask mask(&n256.mask[0], Node256::CAPACITY); + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + if (mask.RowIsValid(i)) { + n15.key[n15.count] = UnsafeNumericCast(i); + n15.count++; + } } } - Node::FreeNode(art, node256_leaf); } diff --git a/src/duckdb/src/execution/index/art/base_node.cpp b/src/duckdb/src/execution/index/art/base_node.cpp index 94d5c0fe1..a59297c2c 100644 --- a/src/duckdb/src/execution/index/art/base_node.cpp +++ b/src/duckdb/src/execution/index/art/base_node.cpp @@ -95,7 +95,9 @@ void Node4::DeleteChild(ART &art, Node &node, Node &parent, const uint8_t byte, auto prev_node4_status = node.GetGateStatus(); Node::FreeNode(art, node); - Prefix::Concat(art, parent, node, child, remaining_byte, prev_node4_status); + // Propagate both the prev_node_4 status and the general gate status (if the gate was earlier on), + // since the concatenation logic depends on both. + Prefix::Concat(art, parent, node, child, remaining_byte, prev_node4_status, status); } void Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp index 1a88b7262..c8e2d09a9 100644 --- a/src/duckdb/src/execution/index/art/iterator.cpp +++ b/src/duckdb/src/execution/index/art/iterator.cpp @@ -95,125 +95,135 @@ bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, set } void Iterator::FindMinimum(const Node &node) { - D_ASSERT(node.HasMetadata()); + reference ref(node); - // Found the minimum. - if (node.IsAnyLeaf()) { - last_leaf = node; - return; - } + while (ref.get().HasMetadata()) { + // Found the minimum. + if (ref.get().IsAnyLeaf()) { + last_leaf = ref.get(); + return; + } - // We are passing a gate node. 
- if (node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - status = GateStatus::GATE_SET; - entered_nested_leaf = true; - nested_depth = 0; - } + // We are passing a gate node. + if (ref.get().GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + status = GateStatus::GATE_SET; + entered_nested_leaf = true; + nested_depth = 0; + } - // Traverse the prefix. - if (node.GetType() == NType::PREFIX) { - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = prefix.data[i]; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + // Traverse the prefix. + if (ref.get().GetType() == NType::PREFIX) { + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + current_key.Push(prefix.data[i]); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = prefix.data[i]; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } } + nodes.emplace(ref.get(), 0); + ref = *prefix.ptr; + continue; } - nodes.emplace(node, 0); - return FindMinimum(*prefix.ptr); - } - // Go to the leftmost entry in the current node. - uint8_t byte = 0; - auto next = node.GetNextChild(art, byte); - D_ASSERT(next); - - // Recurse on the leftmost node. - current_key.Push(byte); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = byte; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + // Go to the leftmost entry in the current node. + uint8_t byte = 0; + auto next = ref.get().GetNextChild(art, byte); + D_ASSERT(next); + + // Move to the leftmost node. + current_key.Push(byte); + if (status == GateStatus::GATE_SET) { + row_id[nested_depth] = byte; + nested_depth++; + D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); + } + nodes.emplace(ref.get(), byte); + ref = *next; } - nodes.emplace(node, byte); - FindMinimum(*next); + // Should always have a node with metadata. + throw InternalException("ART Iterator::FindMinimum: Reached node without metadata"); } -bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { - if (!node.HasMetadata()) { - return false; - } +bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal) { + reference ref(node); + idx_t depth = 0; + + while (ref.get().HasMetadata()) { + // We found any leaf node, or a gate. + if (ref.get().IsAnyLeaf() || ref.get().GetGateStatus() == GateStatus::GATE_SET) { + D_ASSERT(status == GateStatus::GATE_NOT_SET); + D_ASSERT(current_key.Size() == key.len); + if (!equal && current_key.Contains(key)) { + return Next(); + } - // We found any leaf node, or a gate. 
- if (node.IsAnyLeaf() || node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - D_ASSERT(current_key.Size() == key.len); - if (!equal && current_key.Contains(key)) { - return Next(); + if (ref.get().GetGateStatus() == GateStatus::GATE_SET) { + FindMinimum(ref.get()); + } else { + last_leaf = ref.get(); + } + return true; } - if (node.GetGateStatus() == GateStatus::GATE_SET) { - FindMinimum(node); - } else { - last_leaf = node; - } - return true; - } + D_ASSERT(ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET); + if (ref.get().GetType() != NType::PREFIX) { + auto next_byte = key[depth]; + auto child = ref.get().GetNextChild(art, next_byte); - D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); - if (node.GetType() != NType::PREFIX) { - auto next_byte = key[depth]; - auto child = node.GetNextChild(art, next_byte); + // The key is greater than any key in this subtree. + if (!child) { + return Next(); + } - // The key is greater than any key in this subtree. - if (!child) { - return Next(); - } + current_key.Push(next_byte); + nodes.emplace(ref.get(), next_byte); - current_key.Push(next_byte); - nodes.emplace(node, next_byte); + // We return the minimum because all keys are greater than the lower bound. + if (next_byte > key[depth]) { + FindMinimum(*child); + return true; + } - // We return the minimum because all keys are greater than the lower bound. - if (next_byte > key[depth]) { - FindMinimum(*child); - return true; + // Move to the child and increment depth. + ref = *child; + depth++; + continue; } - // We recurse into the child. - return LowerBound(*child, key, equal, depth + 1); - } - - // Push back all prefix bytes. - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - } - nodes.emplace(node, 0); - - // We compare the prefix bytes with the key bytes. - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - // We found a prefix byte that is less than its corresponding key byte. - // I.e., the subsequent node is lesser than the key. Thus, the next node - // is the lower bound. - if (prefix.data[i] < key[depth + i]) { - return Next(); + // Push back all prefix bytes. + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + current_key.Push(prefix.data[i]); } + nodes.emplace(ref.get(), 0); - // We found a prefix byte that is greater than its corresponding key byte. - // I.e., the subsequent node is greater than the key. Thus, the minimum is - // the lower bound. - if (prefix.data[i] > key[depth + i]) { - FindMinimum(*prefix.ptr); - return true; + // We compare the prefix bytes with the key bytes. + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + // We found a prefix byte that is less than its corresponding key byte. + // I.e., the subsequent node is lesser than the key. Thus, the next node + // is the lower bound. + if (prefix.data[i] < key[depth + i]) { + return Next(); + } + + // We found a prefix byte that is greater than its corresponding key byte. + // I.e., the subsequent node is greater than the key. Thus, the minimum is + // the lower bound. + if (prefix.data[i] > key[depth + i]) { + FindMinimum(*prefix.ptr); + return true; + } } - } - // The prefix matches the key. We recurse into the child. - depth += prefix.data[Prefix::Count(art)]; - return LowerBound(*prefix.ptr, key, equal, depth); + // The prefix matches the key. Move to the child and update depth. 
+ depth += prefix.data[Prefix::Count(art)]; + ref = *prefix.ptr; + } + // Should always have a node with metadata. + throw InternalException("ART Iterator::LowerBound: Reached node without metadata"); } bool Iterator::Next() { diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp index f6c7751d6..3f1190216 100644 --- a/src/duckdb/src/execution/index/art/leaf.cpp +++ b/src/duckdb/src/execution/index/art/leaf.cpp @@ -162,7 +162,6 @@ bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, reference ref(node); while (ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref, LEAF); if (row_ids.size() + leaf.count > max_count) { return false; @@ -191,25 +190,44 @@ void Leaf::DeprecatedVacuum(ART &art, Node &node) { } } -string Leaf::DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify) { - D_ASSERT(node.GetType() == LEAF); - +string Leaf::DeprecatedToString(ART &art, const Node &node, const ToStringOptions &options) { + auto indent = [](string &str, const idx_t n) { + str.append(n, ' '); + }; string str = ""; + + if (!options.print_deprecated_leaves) { + indent(str, options.indent_level); + str += "[deprecated leaves]\n"; + return str; + } + reference ref(node); while (ref.get().HasMetadata()) { auto &leaf = Node::Ref(art, ref, LEAF); - D_ASSERT(leaf.count <= LEAF_SIZE); - + indent(str, options.indent_level); str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; for (uint8_t i = 0; i < leaf.count; i++) { str += to_string(leaf.row_ids[i]) + "-"; } - str += "] "; + str += "]\n"; ref = leaf.ptr; } - return only_verify ? "" : str; + return str; +} + +void Leaf::DeprecatedVerify(ART &art, const Node &node) { + D_ASSERT(node.GetType() == LEAF); + + reference ref(node); + + while (ref.get().HasMetadata()) { + auto &leaf = Node::Ref(art, ref, LEAF); + D_ASSERT(leaf.count <= LEAF_SIZE); + ref = leaf.ptr; + } } void Leaf::DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const { diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp index 478f18166..5cfef080f 100644 --- a/src/duckdb/src/execution/index/art/node.cpp +++ b/src/duckdb/src/execution/index/art/node.cpp @@ -391,44 +391,29 @@ void Node::TransformToDeprecated(ART &art, Node &node, // Verification //===--------------------------------------------------------------------===// -string Node::VerifyAndToString(ART &art, const bool only_verify) const { +void Node::Verify(ART &art) const { D_ASSERT(HasMetadata()); auto type = GetType(); switch (type) { case NType::LEAF_INLINED: - return only_verify ? "" : "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]"; + return; case NType::LEAF: - return Leaf::DeprecatedVerifyAndToString(art, *this, only_verify); + Leaf::DeprecatedVerify(art, *this); + return; case NType::PREFIX: { - auto str = Prefix::VerifyAndToString(art, *this, only_verify); - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? 
"" : "\n" + str; + Prefix::Verify(art, *this); + return; } default: break; } - string str = "Node" + to_string(GetCapacity(type)) + ": [ "; - uint8_t byte = 0; - - if (IsLeafNode()) { - str = "Leaf " + str; - auto has_byte = GetNextByte(art, byte); - while (has_byte) { - str += to_string(byte) + "-"; - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - has_byte = GetNextByte(art, byte); - } - } else { + if (!IsLeafNode()) { + uint8_t byte = 0; auto child = GetNextChild(art, byte); while (child) { - str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; + child->Verify(art); if (byte == NumericLimits::Maximum()) { break; } @@ -436,11 +421,6 @@ string Node::VerifyAndToString(ART &art, const bool only_verify) const { child = GetNextChild(art, byte); } } - - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? "" : "\n" + str + "]"; } void Node::VerifyAllocations(ART &art, unordered_map &node_counts) const { @@ -482,4 +462,118 @@ void Node::VerifyAllocations(ART &art, unordered_map &node_count scanner.Scan(handler); } +//===--------------------------------------------------------------------===// +// Printing +//===--------------------------------------------------------------------===// + +string Node::ToString(ART &art, const ToStringOptions &options) const { + auto indent = [](string &str, const idx_t n) { + str.append(n, ' '); + }; + // if inside gate, print byte values not ascii. + auto format_byte = [&](uint8_t byte) { + if (!options.inside_gate && options.display_ascii && byte >= 32 && byte <= 126) { + return string(1, static_cast(byte)); + } + return to_string(byte); + }; + auto type = GetType(); + bool is_gate = GetGateStatus() == GateStatus::GATE_SET; + bool propagate_gate = options.inside_gate || is_gate; + + bool print_full_tree = propagate_gate || !options.key_path || options.depth_remaining == 0; + + switch (type) { + case NType::LEAF_INLINED: { + string str = ""; + indent(str, options.indent_level); + return str + "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]\n"; + } + case NType::LEAF: { + ToStringOptions leaf_options = options; + return Leaf::DeprecatedToString(art, *this, leaf_options); + } + case NType::PREFIX: { + ToStringOptions prefix_options = options; + prefix_options.inside_gate = propagate_gate; + string str = Prefix::ToString(art, *this, prefix_options); + if (is_gate) { + string s = ""; + indent(s, options.indent_level); + s += "Gate\n"; + return s + str; + } + string s = ""; + return s + str; + } + default: + break; + } + string str = ""; + indent(str, options.indent_level); + str = str + "Node" + to_string(GetCapacity(type)) += "\n"; + uint8_t byte = 0; + + if (IsLeafNode()) { + indent(str, options.indent_level); + str += "Leaf |"; + auto has_byte = GetNextByte(art, byte); + while (has_byte) { + str += format_byte(byte) + "|"; + if (byte == NumericLimits::Maximum()) { + break; + } + byte++; + has_byte = GetNextByte(art, byte); + } + str += "\n"; + } else { + uint8_t expected_byte = 0; + bool has_expected_byte = false; + if (options.key_path && !print_full_tree && options.key_depth < options.key_path->len) { + expected_byte = (*options.key_path)[options.key_depth]; + has_expected_byte = true; + } + + uint8_t byte = 0; + auto child = GetNextChild(art, byte); + while (child) { + // Determine if this child is on the path to the key_path + // If we have an expected byte, only traverse the matching child + // If we don't have an expected byte, we're printing the 
full tree, so all children are on_path. + bool on_path = !has_expected_byte || (has_expected_byte && byte == expected_byte); + if (on_path) { + ToStringOptions child_options = options; + child_options.indent_level = options.indent_level + options.indent_amount; + child_options.inside_gate = propagate_gate; + child_options.key_depth = has_expected_byte ? options.key_depth + 1 : options.key_depth; + child_options.depth_remaining = (options.depth_remaining > 0) ? options.depth_remaining - 1 : 0; + string c = child->ToString(art, child_options); + indent(str, options.indent_level); + str = str + format_byte(byte) + ",\n" + c; + } else { + // If we have an expected byte, but the current byte is not the expected byte. + // In this case we check if we are only printing the structure, in which case we skip printing the + // child byte. + if (!options.structure_only) { + indent(str, options.indent_level); + str = str + format_byte(byte) + ", [not printed]\n"; + } + } + byte++; + child = GetNextChild(art, byte); + if (byte == NumericLimits::Maximum()) { + break; + } + } + } + + if (is_gate) { + string s = ""; + indent(s, options.indent_level + options.indent_amount); + str = "Gate\n" + s + str; + } + return str; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp index 00e94967a..8eaef3968 100644 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -65,8 +65,8 @@ void Prefix::New(ART &art, reference &ref, const ARTKey &key, const idx_t } } -void Prefix::Concat(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte, - const GateStatus node4_status) { +void Prefix::Concat(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte, const GateStatus node4_status, + const GateStatus status) { // We have four situations from which we enter here: // 1: PREFIX (parent) - Node4 (prev_node4) - PREFIX (child) - INLINED_LEAF, or // 2: PREFIX (parent) - Node4 (prev_node4) - INLINED_LEAF (child), or @@ -90,16 +90,17 @@ void Prefix::Concat(ART &art, Node &parent, Node &node4, const Node child, uint8 ConcatChildIsGate(art, parent, node4, child, byte); return; } - - auto inside_gate = parent.GetGateStatus() == GateStatus::GATE_SET; - ConcatInternal(art, parent, node4, child, byte, inside_gate); - return; + ConcatInternal(art, parent, node4, child, byte, status); } void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { D_ASSERT(node.HasMetadata()); D_ASSERT(pos < Count(art)); + // We always reduce by at least one byte, + // thus, if the prefix was a gate, it no longer is. 
+ node.SetGateStatus(GateStatus::GATE_NOT_SET); + Prefix prefix(art, node); if (pos == idx_t(prefix.data[Count(art)] - 1)) { auto next = *prefix.ptr; @@ -182,23 +183,43 @@ GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uin return GateStatus::GATE_NOT_SET; } -string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { +string Prefix::ToString(ART &art, const Node &node, const ToStringOptions &options) { + auto indent = [](string &str, const idx_t n) { + str.append(n, ' '); + }; + auto format_byte = [&](uint8_t byte) { + if (!options.inside_gate && options.display_ascii && byte >= 32 && byte <= 126) { + return string(1, static_cast(byte)); + } + return to_string(byte); + }; string str = ""; + indent(str, options.indent_level); + reference ref(node); + ToStringOptions child_options = options; + Iterator(art, ref, true, false, [&](const Prefix &prefix) { + str += "Prefix: |"; + idx_t prefix_len = prefix.data[Count(art)]; + for (idx_t i = 0; i < prefix_len; i++) { + str += format_byte(prefix.data[i]) + "|"; + if (options.key_path) { + child_options.key_depth++; + } + } + }); + string child = ref.get().ToString(art, child_options); + return str + "\n" + child; +} + +void Prefix::Verify(ART &art, const Node &node) { reference ref(node); Iterator(art, ref, true, false, [&](Prefix &prefix) { D_ASSERT(prefix.data[Count(art)] != 0); D_ASSERT(prefix.data[Count(art)] <= Count(art)); - - str += " Prefix :[ "; - for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { - str += to_string(prefix.data[i]) + "-"; - } - str += " ] "; }); - auto child = ref.get().VerifyAndToString(art, only_verify); - return only_verify ? "" : str + child; + ref.get().Verify(art); } void Prefix::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { @@ -282,9 +303,9 @@ Prefix Prefix::GetTail(ART &art, const Node &node) { } void Prefix::ConcatInternal(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte, - const bool inside_gate) { + const GateStatus status) { if (child.GetType() == NType::LEAF_INLINED) { - if (inside_gate) { + if (status == GateStatus::GATE_SET) { if (parent.GetType() == NType::PREFIX) { // The parent only contained the Node4, so we can now inline 'all the way up', // and the gate is no longer nested. 
diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp index 2c0d43d91..c60886c31 100644 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ b/src/duckdb/src/execution/index/bound_index.cpp @@ -1,11 +1,13 @@ #include "duckdb/execution/index/bound_index.hpp" +#include "duckdb/common/array.hpp" #include "duckdb/common/radix.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/storage/table/append_state.hpp" +#include "duckdb/common/types/selection_vector.hpp" namespace duckdb { @@ -18,7 +20,6 @@ BoundIndex::BoundIndex(const string &name, const string &index_type, IndexConstr const vector> &unbound_expressions_p, AttachedDatabase &db) : Index(column_ids, table_io_manager, db), name(name), index_type(index_type), index_constraint_type(index_constraint_type) { - for (auto &expr : unbound_expressions_p) { types.push_back(expr->return_type.InternalType()); logical_types.push_back(expr->return_type); @@ -79,10 +80,16 @@ bool BoundIndex::MergeIndexes(BoundIndex &other_index) { return MergeIndexes(state, other_index); } -string BoundIndex::VerifyAndToString(const bool only_verify) { +void BoundIndex::Verify() { IndexLock l; InitializeLock(l); - return VerifyAndToString(l, only_verify); + Verify(l); +} + +string BoundIndex::ToString(bool display_ascii) { + IndexLock l; + InitializeLock(l); + return ToString(l, display_ascii); } void BoundIndex::VerifyAllocations() { @@ -135,6 +142,15 @@ bool BoundIndex::IndexIsUpdated(const vector &column_ids_p) const return false; } +bool BoundIndex::RequiresTransactionality() const { + return false; +} + +unique_ptr BoundIndex::CreateEmptyCopy(const string &name_prefix, + IndexConstraintType constraint_type) const { + throw InternalException("BoundIndex::CreateEmptyCopy is not supported for this index type"); +} + IndexStorageInfo BoundIndex::SerializeToDisk(QueryContext context, const case_insensitive_map_t &options) { throw NotImplementedException("The implementation of this index disk serialization does not exist."); } @@ -154,28 +170,80 @@ string BoundIndex::AppendRowError(DataChunk &input, idx_t index) { return error; } -void BoundIndex::ApplyBufferedAppends(const vector &table_types, ColumnDataCollection &buffered_appends, - const vector &mapped_column_ids) { - IndexAppendInfo index_append_info(IndexAppendMode::INSERT_DUPLICATES, nullptr); +namespace { - ColumnDataScanState state; - buffered_appends.InitializeScan(state); +struct BufferedReplayState { + optional_ptr buffer = nullptr; + ColumnDataScanState scan_state; + DataChunk current_chunk; + bool scan_initialized = false; +}; +} // namespace + +void BoundIndex::ApplyBufferedReplays(const vector &table_types, BufferedIndexReplays &buffered_replays, + const vector &mapped_column_ids) { + if (!buffered_replays.HasBufferedReplays()) { + return; + } - DataChunk scan_chunk; - buffered_appends.InitializeScanChunk(scan_chunk); + // We have two replay states: one for inserts and one for deletes. These are indexed into using the + // replay_type. Both scans are interleaved, so the state maintains the position of each scan. 
+ array<BufferedReplayState, 2> replay_states; DataChunk table_chunk; table_chunk.InitializeEmpty(table_types); - while (buffered_appends.Scan(state, scan_chunk)) { - for (idx_t i = 0; i < scan_chunk.ColumnCount() - 1; i++) { - auto col_id = mapped_column_ids[i].GetPrimaryIndex(); - table_chunk.data[col_id].Reference(scan_chunk.data[i]); + for (const auto &replay_range : buffered_replays.ranges) { + const auto type_idx = static_cast(replay_range.type); + auto &state = replay_states[type_idx]; + + // Initialize the scan state if necessary. Take ownership of buffered operations, since we won't need + // them after replaying anyway. + if (!state.scan_initialized) { + state.buffer = buffered_replays.GetBuffer(replay_range.type); + state.buffer->InitializeScan(state.scan_state); + state.buffer->InitializeScanChunk(state.current_chunk); + state.scan_initialized = true; } - table_chunk.SetCardinality(scan_chunk.size()); - auto error = Append(table_chunk, scan_chunk.data.back(), index_append_info); - if (error.HasError()) { - throw InternalException("error while applying buffered appends: " + error.Message()); + idx_t current_row = replay_range.start; + while (current_row < replay_range.end) { + // Scan the next DataChunk from the ColumnDataCollection buffer if the current row is on or after + // that chunk's starting row index. + if (current_row >= state.scan_state.next_row_index) { + if (!state.buffer->Scan(state.scan_state, state.current_chunk)) { + throw InternalException("Buffered index data exhausted during replay"); + } + } + + // The number of rows to process is the minimum of the rows still available in the chunk and the + // rows remaining in the current range. + const auto offset_in_chunk = current_row - state.scan_state.current_row_index; + const auto available_in_chunk = state.current_chunk.size() - offset_in_chunk; + // [start, end) in ReplayRange is [inclusive, exclusive).
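+ // Worked example (assuming 2048-row scan chunks): replaying the range [1500, 3000) against a chunk
+ // covering rows [0, 2048) gives offset_in_chunk = 1500 and available_in_chunk = 548, so the first
+ // iteration processes min(548, 1500) = 548 rows and the next iteration scans a fresh chunk.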
+ const auto range_remaining = replay_range.end - current_row; + const auto rows_to_process = MinValue(available_in_chunk, range_remaining); + + SelectionVector sel(offset_in_chunk, rows_to_process); + + for (idx_t col_idx = 0; col_idx < state.current_chunk.ColumnCount() - 1; col_idx++) { + const auto col_id = mapped_column_ids[col_idx].GetPrimaryIndex(); + table_chunk.data[col_id].Reference(state.current_chunk.data[col_idx]); + table_chunk.data[col_id].Slice(sel, rows_to_process); + } + table_chunk.SetCardinality(rows_to_process); + Vector row_ids(state.current_chunk.data.back(), sel, rows_to_process); + + if (replay_range.type == BufferedIndexReplay::INSERT_ENTRY) { + IndexAppendInfo append_info(IndexAppendMode::INSERT_DUPLICATES, nullptr); + const auto error = Append(table_chunk, row_ids, append_info); + if (error.HasError()) { + throw InternalException("error while applying buffered appends: " + error.Message()); + } + current_row += rows_to_process; + continue; + } + Delete(table_chunk, row_ids); + current_row += rows_to_process; } } } diff --git a/src/duckdb/src/execution/index/fixed_size_allocator.cpp b/src/duckdb/src/execution/index/fixed_size_allocator.cpp index dd4758bb9..3b1572d75 100644 --- a/src/duckdb/src/execution/index/fixed_size_allocator.cpp +++ b/src/duckdb/src/execution/index/fixed_size_allocator.cpp @@ -4,10 +4,9 @@ namespace duckdb { -FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) - : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), segment_size(segment_size), - total_segment_count(0) { - +FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, MemoryTag memory_tag) + : block_manager(block_manager), buffer_manager(block_manager.buffer_manager), memory_tag(memory_tag), + segment_size(segment_size), total_segment_count(0) { if (segment_size > block_manager.GetBlockSize() - sizeof(validity_t)) { throw InternalException("The maximum segment size of fixed-size allocators is " + to_string(block_manager.GetBlockSize() - sizeof(validity_t))); @@ -48,7 +47,7 @@ IndexPointer FixedSizeAllocator::New() { if (!buffer_with_free_space.IsValid()) { // Add a new buffer. 
auto buffer_id = GetAvailableBufferId(); - buffers[buffer_id] = make_uniq(block_manager); + buffers[buffer_id] = make_uniq(block_manager, memory_tag); buffers_with_free_space.insert(buffer_id); buffer_with_free_space = buffer_id; @@ -321,7 +320,6 @@ void FixedSizeAllocator::Init(const FixedSizeAllocatorInfo &info) { total_segment_count = 0; for (idx_t i = 0; i < info.buffer_ids.size(); i++) { - // read all FixedSizeBuffer data auto buffer_id = info.buffer_ids[i]; diff --git a/src/duckdb/src/execution/index/fixed_size_buffer.cpp b/src/duckdb/src/execution/index/fixed_size_buffer.cpp index 82bbccac2..1cf36c1a5 100644 --- a/src/duckdb/src/execution/index/fixed_size_buffer.cpp +++ b/src/duckdb/src/execution/index/fixed_size_buffer.cpp @@ -35,12 +35,11 @@ void PartialBlockForIndex::Clear() { constexpr idx_t FixedSizeBuffer::BASE[]; constexpr uint8_t FixedSizeBuffer::SHIFT[]; -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) +FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag) : block_manager(block_manager), readers(0), segment_count(0), allocation_size(0), dirty(false), vacuum(false), loaded(false), block_pointer(), block_handle(nullptr) { - auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, &block_manager, false); + buffer_handle = buffer_manager.Allocate(memory_tag, &block_manager, false); block_handle = buffer_handle.GetBlockHandle(); // Zero-initialize the buffer as it might get serialized to storage. @@ -52,7 +51,6 @@ FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segmen const BlockPointer &block_pointer) : block_manager(block_manager), readers(0), segment_count(segment_count), allocation_size(allocation_size), dirty(false), vacuum(false), loaded(false), block_pointer(block_pointer) { - D_ASSERT(block_pointer.IsValid()); block_handle = block_manager.RegisterBlock(block_pointer.block_id); D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); @@ -159,7 +157,6 @@ void FixedSizeBuffer::LoadFromDisk() { } uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count, const idx_t available_segments) { - // Get a handle to the buffer's validity mask (offset 0). SegmentHandle handle(*this, 0); const auto bitmask_ptr = handle.GetPtr(); diff --git a/src/duckdb/src/execution/index/index_type_set.cpp b/src/duckdb/src/execution/index/index_type_set.cpp index 4fe7cda4f..0422f8a02 100644 --- a/src/duckdb/src/execution/index/index_type_set.cpp +++ b/src/duckdb/src/execution/index/index_type_set.cpp @@ -5,7 +5,6 @@ namespace duckdb { IndexTypeSet::IndexTypeSet() { - // Register the ART index type by default IndexType art_index_type; art_index_type.name = ART::TYPE_NAME; diff --git a/src/duckdb/src/execution/index/unbound_index.cpp b/src/duckdb/src/execution/index/unbound_index.cpp index 0d117ca92..f15d6bd15 100644 --- a/src/duckdb/src/execution/index/unbound_index.cpp +++ b/src/duckdb/src/execution/index/unbound_index.cpp @@ -12,7 +12,6 @@ UnboundIndex::UnboundIndex(unique_ptr create_info, IndexStorageInfo TableIOManager &table_io_manager, AttachedDatabase &db) : Index(create_info->Cast().column_ids, table_io_manager, db), create_info(std::move(create_info)), storage_info(std::move(storage_info_p)) { - // Memory safety check. 
for (idx_t info_idx = 0; info_idx < storage_info.allocator_infos.size(); info_idx++) { auto &info = storage_info.allocator_infos[info_idx]; @@ -35,26 +34,48 @@ void UnboundIndex::CommitDrop() { } } -void UnboundIndex::BufferChunk(DataChunk &chunk, Vector &row_ids, const vector &mapped_column_ids_p) { +void UnboundIndex::BufferChunk(DataChunk &index_column_chunk, Vector &row_ids, + const vector &mapped_column_ids_p, const BufferedIndexReplay replay_type) { D_ASSERT(!column_ids.empty()); - auto types = chunk.GetTypes(); + auto types = index_column_chunk.GetTypes(); // column types types.push_back(LogicalType::ROW_TYPE); - if (!buffered_appends) { - auto &allocator = Allocator::Get(db); - buffered_appends = make_uniq(allocator, types); + auto &allocator = Allocator::Get(db); + + //! First time we are buffering data, canonical column_id mapping is stored. + //! This should be a sorted list of all the physical offsets of Indexed columns on this table. + if (mapped_column_ids.empty()) { mapped_column_ids = mapped_column_ids_p; } D_ASSERT(mapped_column_ids == mapped_column_ids_p); + // combined_chunk has all the indexed columns according to mapped_column_ids ordering, as well as a rowid column. DataChunk combined_chunk; combined_chunk.InitializeEmpty(types); - for (idx_t i = 0; i < chunk.ColumnCount(); i++) { - combined_chunk.data[i].Reference(chunk.data[i]); + for (idx_t i = 0; i < index_column_chunk.ColumnCount(); i++) { + combined_chunk.data[i].Reference(index_column_chunk.data[i]); } combined_chunk.data.back().Reference(row_ids); - combined_chunk.SetCardinality(chunk.size()); - buffered_appends->Append(combined_chunk); + combined_chunk.SetCardinality(index_column_chunk.size()); + + auto &buffer = buffered_replays.GetBuffer(replay_type); + if (buffer == nullptr) { + buffer = make_uniq(allocator, types); + } + // The starting index of the buffer range is the size of the buffer. + const idx_t start = buffer->Count(); + const idx_t end = start + combined_chunk.size(); + auto &ranges = buffered_replays.ranges; + + if (ranges.empty() || ranges.back().type != replay_type) { + // If there are no buffered ranges, or the replay types don't match, append a new range. + ranges.emplace_back(replay_type, start, end); + buffer->Append(combined_chunk); + return; + } + // Otherwise merge the range with the previous one. 
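+ // For example, buffering 100 insert rows and then another 60 inserts (with no delete in between)
+ // extends the existing range to [0, 160) instead of appending a second INSERT_ENTRY range.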
+ ranges.back().end = end; + buffer->Append(combined_chunk); } } // namespace duckdb diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index f991ead7e..ef2fe68ac 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -45,7 +45,6 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o auto type = condition.left->return_type; if (condition.comparison == ExpressionType::COMPARE_EQUAL || condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - // ensure that all equality conditions are at the front, // and that all other conditions are at the back D_ASSERT(equality_types.size() == condition_types.size()); @@ -82,7 +81,6 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o // Initialize the row matcher that are used for filtering during the probing only if there are non-equality if (!non_equality_predicates.empty()) { - row_matcher_probe = unique_ptr(new RowMatcher()); row_matcher_probe_no_match_sel = unique_ptr(new RowMatcher()); @@ -103,9 +101,9 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o pointer_offset = offsets.back(); entry_size = layout_ptr->GetRowWidth(); - data_collection = make_uniq(buffer_manager, layout_ptr); - sink_collection = - make_uniq(buffer_manager, layout_ptr, radix_bits, layout_ptr->ColumnCount() - 1); + data_collection = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE); + sink_collection = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, + radix_bits, layout_ptr->ColumnCount() - 1); dead_end = make_unsafe_uniq_array_uninitialized(layout_ptr->GetRowWidth()); memset(dead_end.get(), 0, layout_ptr->GetRowWidth()); @@ -172,7 +170,6 @@ idx_t GetOptionalIndex(const SelectionVector *sel, const idx_t idx) { static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry_t &entry, Vector &pointers_result_v, idx_t row_ht_offset, idx_t &keys_to_compare_count, const idx_t &row_index) { - const auto row_ptr_insert_to = FlatVector::GetData(pointers_result_v); const auto ht_offsets_and_salts = FlatVector::GetData(state.ht_offsets_and_salts_v); @@ -189,13 +186,11 @@ static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry template static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, Vector &pointers_result_v, const SelectionVector *row_sel, idx_t &count) { - auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); idx_t keys_to_compare_count = 0; for (idx_t i = 0; i < count; i++) { - auto row_hash = hashes_dense[i]; // hashes have been flattened before -> always access dense auto row_ht_offset = row_hash & ht.bitmask; @@ -260,7 +255,6 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta Vector &hashes_v, const SelectionVector *row_sel, idx_t &count, JoinHashTable &ht, ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel, bool has_row_sel) { - // densify hashes: If there is no sel, flatten the hashes, else densify via UnifiedVectorFormat if (has_row_sel) { UnifiedVectorFormat hashes_unified_v; @@ -339,7 +333,6 @@ inline bool JoinHashTable::UseSalt() const { void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, const SelectionVector *sel, idx_t &count, Vector &pointers_result_v, SelectionVector &match_sel, const bool has_sel) { - 
if (UseSalt()) { GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, *this, entries, pointers_result_v, match_sel, has_sel); @@ -722,6 +715,10 @@ static void InsertHashesLoop(atomic entries[], Vector &row_locations void JoinHashTable::InsertHashes(Vector &hashes_v, const idx_t count, TupleDataChunkState &chunk_state, InsertState &insert_state, bool parallel) { + // Insert Hashes into the BF + if (bloom_filter.IsInitialized()) { + bloom_filter.InsertHashes(hashes_v, count); + } auto atomic_entries = reinterpret_cast *>(this->entries); auto row_locations = chunk_state.row_locations; if (parallel) { @@ -740,6 +737,10 @@ void JoinHashTable::AllocatePointerTable() { throw InternalException("Hashtable capacity exceeds 48-bit limit (2^48 - 1)"); } + if (should_build_bloom_filter) { + bloom_filter.Initialize(context, Count()); + } + if (hash_map.get()) { // There is already a hash map auto current_capacity = hash_map.GetSize() / sizeof(ht_entry_t); @@ -888,13 +889,13 @@ bool ScanStructure::PointersExhausted() const { } idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) { - // Initialize the found_match array to the current sel_vector for (idx_t i = 0; i < this->count; ++i) { match_sel.set_index(i, this->sel_vector.get_index(i)); } // If there is a matcher for the probing side because of non-equality predicates, use it + idx_t result_count; if (ht.needs_chain_matcher) { idx_t no_match_count = 0; auto &matcher = no_match_sel ? ht.row_matcher_probe_no_match_sel : ht.row_matcher_probe; @@ -902,12 +903,17 @@ idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_s // we need to only use the vectors with the indices of the columns that are used in the probe phase, namely // the non-equality columns - return matcher->Match(keys, key_state.vector_data, match_sel, this->count, pointers, no_match_sel, - no_match_count); + result_count = + matcher->Match(keys, key_state.vector_data, match_sel, this->count, pointers, no_match_sel, no_match_count); } else { // no match sel is the opposite of match sel - return this->count; + result_count = this->count; } + + // Update total probe match count + ht.total_probe_matches.fetch_add(result_count, std::memory_order_relaxed); + + return result_count; } idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) { @@ -934,7 +940,6 @@ idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vect } void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_count) { - if (!ht.chains_longer_than_one) { this->count = 0; return; @@ -1455,6 +1460,23 @@ idx_t JoinHashTable::FillWithHTOffsets(JoinHTScanState &state, Vector &addresses return key_count; } +idx_t JoinHashTable::ScanKeyColumn(Vector &addresses, Vector &result, idx_t column_index) const { + // nothing to scan if the build side is empty + if (data_collection->ChunkCount() == 0) { + return 0; + } + D_ASSERT(result.GetType() == layout_ptr->GetTypes()[column_index]); + JoinHTScanState join_ht_state(*data_collection, 0, data_collection->ChunkCount(), + TupleDataPinProperties::KEEP_EVERYTHING_PINNED); + auto key_count = FillWithHTOffsets(join_ht_state, addresses); + if (key_count == 0) { + return 0; + } + const auto &sel = *FlatVector::IncrementalSelectionVector(); + data_collection->Gather(addresses, sel, key_count, column_index, result, sel, nullptr); + return key_count; +} + idx_t JoinHashTable::GetTotalSize(const vector &partition_sizes, const vector 
&partition_counts, idx_t &max_partition_size, idx_t &max_partition_count) const { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); @@ -1536,8 +1558,8 @@ void JoinHashTable::SetRepartitionRadixBits(const idx_t max_ht_size, const idx_t } } radix_bits += added_bits; - sink_collection = - make_uniq(buffer_manager, layout_ptr, radix_bits, layout_ptr->ColumnCount() - 1); + sink_collection = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, + radix_bits, layout_ptr->ColumnCount() - 1); // Need to initialize again after changing the number of bits InitializePartitionMasks(); @@ -1567,8 +1589,8 @@ idx_t JoinHashTable::FinishedPartitionCount() const { } void JoinHashTable::Repartition(JoinHashTable &global_ht) { - auto new_sink_collection = make_uniq(buffer_manager, layout_ptr, global_ht.radix_bits, - layout_ptr->ColumnCount() - 1); + auto new_sink_collection = make_uniq( + buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, global_ht.radix_bits, layout_ptr->ColumnCount() - 1); sink_collection->Repartition(context, *new_sink_collection); sink_collection = std::move(new_sink_collection); global_ht.Merge(*this); diff --git a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp index 53fe0368a..df798cfa0 100644 --- a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp +++ b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp @@ -16,13 +16,13 @@ AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_ AggregateObject::AggregateObject(BoundAggregateExpression *aggr) : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), - AlignValue(aggr->function.state_size(aggr->function)), aggr->aggr_type, + AlignValue(aggr->function.GetStateSizeCallback()(aggr->function)), aggr->aggr_type, aggr->return_type.InternalType(), aggr->filter.get()) { } AggregateObject::AggregateObject(const BoundWindowExpression &window) : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), - AlignValue(window.aggregate->state_size(*window.aggregate)), + AlignValue(window.aggregate->GetStateSizeCallback()(*window.aggregate)), window.distinct ? 
AggregateType::DISTINCT : AggregateType::NON_DISTINCT, window.return_type.InternalType(), window.filter_expr.get()) { } diff --git a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp index dc37353f7..0e7910cbd 100644 --- a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp +++ b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp @@ -29,7 +29,6 @@ DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector> grou filter_count++; payload_types_filters.push_back(aggr.filter->return_type); } - if (!aggr.function.combine) { + if (!aggr.function.HasStateCombineCallback()) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } aggregates.push_back(std::move(expr)); @@ -63,7 +63,7 @@ void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggr filter_count++; } } - if (!aggr.function.combine) { + if (!aggr.function.HasStateCombineCallback()) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp index 76e1444e0..4eda5cba8 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -221,7 +221,6 @@ class HashAggregateGlobalSinkState : public GlobalSinkState { class HashAggregateLocalSinkState : public LocalSinkState { public: HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) { - auto &payload_types = op.grouped_aggregate_data.payload_types; if (!payload_types.empty()) { aggregate_input_chunk.InitializeEmpty(payload_types); @@ -415,7 +414,6 @@ SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk // Combine //===--------------------------------------------------------------------===// void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &global_sink = input.global_state.Cast(); auto &sink = input.local_state.Cast(); @@ -860,8 +858,8 @@ unique_ptr PhysicalHashAggregate::GetLocalSourceState(Executio return make_uniq(context, *this); } -SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalHashAggregate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &sink_gstate = sink_state->Cast(); auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp index 92233b2db..f388dd25c 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp @@ -189,8 +189,8 @@ unique_ptr PhysicalPartitionedAggregate::GetGlobalSourceState return make_uniq(gstate); } -SourceResultType PhysicalPartitionedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalPartitionedAggregate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + 
OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); auto &gsource = input.global_state.Cast(); gstate.aggregate_result.Scan(gsource.scan_state, chunk); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp index 9af2cf8ee..7a08b43bd 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -37,7 +37,7 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(PhysicalPlan &physica bindings.push_back(&aggr); D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.combine); + D_ASSERT(aggr.function.HasStateCombineCallback()); for (auto &child : aggr.children) { payload_types.push_back(child->return_type); } @@ -188,8 +188,8 @@ unique_ptr PhysicalPerfectHashAggregate::GetGlobalSourceState return make_uniq(); } -SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalPerfectHashAggregate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &state = input.global_state.Cast(); auto &gstate = sink_state->Cast(); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp index 38ef17061..5f157b201 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -28,16 +28,16 @@ class StreamingWindowState : public OperatorState { public: struct AggregateState { AggregateState(ClientContext &client, BoundWindowExpression &wexpr, Allocator &allocator) - : wexpr(wexpr), arena_allocator(Allocator::DefaultAllocator()), executor(client), filter_executor(client), + : wexpr(wexpr), arena_allocator(BufferAllocator::Get((client))), executor(client), filter_executor(client), statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)), hashes(LogicalType::HASH), addresses(LogicalType::POINTER) { D_ASSERT(wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE); auto &aggregate = *wexpr.aggregate; bind_data = wexpr.bind_info.get(); - dtor = aggregate.destructor; - state.resize(aggregate.state_size(aggregate)); + dtor = aggregate.GetStateDestructorCallback(); + state.resize(aggregate.GetStateSizeCallback()(aggregate)); state_ptr = state.data(); - aggregate.initialize(aggregate, state.data()); + aggregate.GetStateInitCallback()(aggregate, state.data()); for (auto &child : wexpr.children) { arg_types.push_back(child->return_type); executor.AddExpression(*child); @@ -350,7 +350,7 @@ bool PhysicalStreamingWindow::IsStreamingFunction(ClientContext &context, unique // TODO: add more expression types here? 
case ExpressionType::WINDOW_AGGREGATE: // Aggregates with destructors (e.g., quantile) are too slow to repeatedly update/finalize - if (wexpr.aggregate->destructor) { + if (wexpr.aggregate->HasStateDestructorCallback()) { return false; } // We can stream aggregates if they are "running totals" @@ -479,9 +479,10 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da arg_cursor.data[struct_idx].Slice(arg_chunk.data[struct_idx], sel, 1); } if (filter_mask.RowIsValid(i) && distinct_mask.RowIsValid(i)) { - aggregate.update(arg_cursor.data.data(), aggr_input_data, arg_cursor.ColumnCount(), statev, 1); + aggregate.GetStateUpdateCallback()(arg_cursor.data.data(), aggr_input_data, arg_cursor.ColumnCount(), + statev, 1); } - aggregate.finalize(statev, aggr_input_data, result, 1, i); + aggregate.GetStateFinalizeCallback()(statev, aggr_input_data, result, 1, i); } } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index a2a3da965..4861df3d9 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -28,7 +28,6 @@ PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(PhysicalPlan &physical_pl : PhysicalOperator(physical_plan, PhysicalOperatorType::UNGROUPED_AGGREGATE, std::move(types), estimated_cardinality), aggregates(std::move(expressions)) { - distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates); if (!distinct_collection_info) { return; @@ -46,11 +45,11 @@ UngroupedAggregateState::UngroupedAggregateState(const vectorGetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = aggregate->Cast(); - auto state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); - aggr.function.initialize(aggr.function, state.get()); + auto state = make_unsafe_uniq_array_uninitialized(aggr.function.GetStateSizeCallback()(aggr.function)); + aggr.function.GetStateInitCallback()(aggr.function, state.get()); aggregate_data.push_back(std::move(state)); bind_data.push_back(aggr.bind_info.get()); - destructors.push_back(aggr.function.destructor); + destructors.push_back(aggr.function.GetStateDestructorCallback()); #ifdef DEBUG counts[i] = 0; #endif @@ -116,7 +115,11 @@ void GlobalUngroupedAggregateState::Combine(LocalUngroupedAggregateState &other) AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); + if (!aggregate.function.HasStateCombineCallback()) { + throw InternalException("Aggregate function " + aggregate.function.name + + " does not support combining of states"); + } + aggregate.function.GetStateCombineCallback()(source_state, dest_state, aggr_input_data, 1); #ifdef DEBUG state.counts[aggr_idx] += other.state.counts[aggr_idx]; #endif @@ -137,7 +140,11 @@ void GlobalUngroupedAggregateState::CombineDistinct(LocalUngroupedAggregateState Vector state_vec(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get()))); Vector combined_vec(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); - aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); + if (!aggregate.function.HasStateCombineCallback()) { + throw InternalException("Aggregate function " + aggregate.function.name + + " does not support combining 
of states"); + } + aggregate.function.GetStateCombineCallback()(state_vec, combined_vec, aggr_input_data, 1); #ifdef DEBUG state.counts[aggr_idx] += other.state.counts[aggr_idx]; #endif @@ -239,7 +246,6 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { public: void InitializeDistinctAggregates(const PhysicalUngroupedAggregate &op, const UngroupedAggregateGlobalSinkState &gstate, ExecutionContext &context) { - if (!op.distinct_data) { return; } @@ -270,7 +276,7 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { bool PhysicalUngroupedAggregate::SinkOrderDependent() const { for (auto &expr : aggregates) { auto &aggr = expr->Cast(); - if (aggr.function.order_dependent == AggregateOrderDependent::ORDER_DEPENDENT) { + if (aggr.function.GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT) { return true; } } @@ -354,8 +360,8 @@ void LocalUngroupedAggregateState::Sink(DataChunk &payload_chunk, idx_t payload_ D_ASSERT(payload_idx + payload_cnt <= payload_chunk.data.size()); auto start_of_input = payload_cnt == 0 ? nullptr : &payload_chunk.data[payload_idx]; AggregateInputData aggr_input_data(state.bind_data[aggr_idx], allocator); - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, state.aggregate_data[aggr_idx].get(), - payload_chunk.size()); + aggregate.function.GetStateSimpleUpdateCallback()(start_of_input, aggr_input_data, payload_cnt, + state.aggregate_data[aggr_idx].get(), payload_chunk.size()); } //===--------------------------------------------------------------------===// @@ -628,7 +634,8 @@ void VerifyNullHandling(DataChunk &chunk, UngroupedAggregateState &state, #ifdef DEBUG for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { auto &aggr = aggregates[aggr_idx]->Cast(); - if (state.counts[aggr_idx] == 0 && aggr.function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + if (state.counts[aggr_idx] == 0 && + aggr.function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { // Default is when 0 values go in, NULL comes out UnifiedVectorFormat vdata; chunk.data[aggr_idx].ToUnifiedFormat(1, vdata); @@ -645,12 +652,13 @@ void GlobalUngroupedAggregateState::Finalize(DataChunk &result, idx_t column_off Vector state_vector(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); - aggregate.function.finalize(state_vector, aggr_input_data, result.data[column_offset + aggr_idx], 1, 0); + aggregate.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, + result.data[column_offset + aggr_idx], 1, 0); } } -SourceResultType PhysicalUngroupedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalUngroupedAggregate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); D_ASSERT(gstate.finished); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 102f491f0..5e23c8cd3 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -1,6 +1,6 @@ #include "duckdb/execution/operator/aggregate/physical_window.hpp" -#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sort_strategy.hpp" #include 
"duckdb/common/types/row/tuple_data_collection.hpp" #include "duckdb/common/types/row/tuple_data_iterator.hpp" #include "duckdb/function/window/window_aggregate_function.hpp" @@ -17,7 +17,7 @@ namespace duckdb { // Global sink state class WindowGlobalSinkState; -enum WindowGroupStage : uint8_t { MASK, SINK, FINALIZE, GETDATA, DONE }; +enum WindowGroupStage : uint8_t { SORT, MATERIALIZE, MASK, SINK, FINALIZE, GETDATA, DONE }; struct WindowSourceTask { WindowSourceTask() { @@ -48,17 +48,28 @@ class WindowHashGroup { using Task = WindowSourceTask; using TaskPtr = optional_ptr; using ScannerPtr = unique_ptr; + using ChunkRow = SortStrategy::ChunkRow; - WindowHashGroup(WindowGlobalSinkState &gsink, HashGroupPtr &sorted, const idx_t hash_bin_p); + template + static T BinValue(T n, T val) { + return ((n + (val - 1)) / val); + } + + WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p); void AllocateMasks(); void ComputeMasks(const idx_t begin_idx, const idx_t end_idx); ExecutorGlobalStates &GetGlobalStates(ClientContext &client); + //! The number of chunks in the group + inline idx_t ChunkCount() const { + return blocks; + } + // The total number of tasks we will execute per thread inline idx_t GetTaskCount() const { - return GetThreadCount() * (uint8_t(WindowGroupStage::DONE) - uint8_t(WindowGroupStage::MASK)); + return GetThreadCount() * (uint8_t(WindowGroupStage::DONE) - uint8_t(WindowGroupStage::SORT)); } // The total number of threads we will use inline idx_t GetThreadCount() const { @@ -79,6 +90,18 @@ class WindowHashGroup { bool TryPrepareNextStage() { lock_guard prepare_guard(lock); switch (stage.load()) { + case WindowGroupStage::SORT: + if (sorted == blocks) { + stage = WindowGroupStage::MATERIALIZE; + return true; + } + return false; + case WindowGroupStage::MATERIALIZE: + if (materialized == blocks && rows.get()) { + stage = WindowGroupStage::MASK; + return true; + } + return false; case WindowGroupStage::MASK: if (masked == blocks) { stage = WindowGroupStage::SINK; @@ -118,7 +141,7 @@ class WindowHashGroup { task.thread_idx = next_task % group_threads; task.group_idx = hash_bin; task.begin_idx = task.thread_idx * per_thread; - task.max_idx = rows->ChunkCount(); + task.max_idx = ChunkCount(); task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); ++next_task; return true; @@ -130,13 +153,11 @@ class WindowHashGroup { //! The shared global state from sinking WindowGlobalSinkState &gsink; //! The hash partition data - HashGroupPtr hash_group; + HashGroupPtr rows; //! The size of the group idx_t count = 0; //! The number of blocks in the group idx_t blocks = 0; - unique_ptr rows; - TupleDataLayout layout; //! The partition boundary mask ValidityMask partition_mask; //! The order boundary mask @@ -160,6 +181,10 @@ class WindowHashGroup { idx_t group_threads = 0; //! The next task to process idx_t next_task = 0; + //! Count of sorted run blocks + std::atomic sorted; + //! Count of materialized run blocks + std::atomic materialized; //! Count of masked blocks std::atomic masked; //! 
Count of sunk rows @@ -180,8 +205,8 @@ class WindowGlobalSinkState : public GlobalSinkState { WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); SinkFinalizeType Finalize(ClientContext &client, InterruptState &interrupt_state) { - OperatorSinkFinalizeInput finalize {*hashed_sink, interrupt_state}; - auto result = global_partition->Finalize(client, finalize); + OperatorSinkFinalizeInput finalize {*strategy_sink, interrupt_state}; + auto result = sort_strategy->Finalize(client, finalize); return result; } @@ -191,9 +216,9 @@ class WindowGlobalSinkState : public GlobalSinkState { //! Client context ClientContext &client; //! The partitioned sunk data - unique_ptr global_partition; + unique_ptr sort_strategy; //! The partitioned sunk data - unique_ptr hashed_sink; + unique_ptr strategy_sink; //! The number of sunk rows (for progress) atomic count; //! The execution functions @@ -206,7 +231,7 @@ class WindowGlobalSinkState : public GlobalSinkState { class WindowLocalSinkState : public LocalSinkState { public: WindowLocalSinkState(ExecutionContext &context, const WindowGlobalSinkState &gstate) - : local_group(gstate.global_partition->GetLocalSinkState(context)) { + : local_group(gstate.sort_strategy->GetLocalSinkState(context)) { } unique_ptr local_group; @@ -218,7 +243,6 @@ PhysicalWindow::PhysicalWindow(PhysicalPlan &physical_plan, vector PhysicalOperatorType type) : PhysicalOperator(physical_plan, type, std::move(types), estimated_cardinality), select_list(std::move(select_list_p)), order_idx(0), is_order_dependent(false) { - idx_t max_orders = 0; for (idx_t i = 0; i < select_list.size(); ++i) { auto &expr = select_list[i]; @@ -271,7 +295,6 @@ static unique_ptr WindowExecutorFactory(BoundWindowExpression &w WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &client) : op(op), client(client), count(0) { - D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); auto &wexpr = op.select_list[op.order_idx]->Cast(); @@ -283,9 +306,9 @@ WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientCon executors.emplace_back(std::move(wexec)); } - global_partition = make_uniq(client, wexpr.partitions, wexpr.orders, op.children[0].get().GetTypes(), - wexpr.partitions_stats, op.estimated_cardinality); - hashed_sink = global_partition->GetGlobalSinkState(client); + sort_strategy = SortStrategy::Factory(client, wexpr.partitions, wexpr.orders, op.children[0].get().GetTypes(), + wexpr.partitions_stats, op.estimated_cardinality); + strategy_sink = sort_strategy->GetGlobalSinkState(client); } //===--------------------------------------------------------------------===// @@ -296,16 +319,16 @@ SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, auto &lstate = sink.local_state.Cast(); gstate.count += chunk.size(); - OperatorSinkInput hsink {*gstate.hashed_sink, *lstate.local_group, sink.interrupt_state}; - return gstate.global_partition->Sink(context, chunk, hsink); + OperatorSinkInput hsink {*gstate.strategy_sink, *lstate.local_group, sink.interrupt_state}; + return gstate.sort_strategy->Sink(context, chunk, hsink); } SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { auto &gstate = combine.global_state.Cast(); auto &lstate = combine.local_state.Cast(); - OperatorSinkCombineInput hcombine {*gstate.hashed_sink, *lstate.local_group, combine.interrupt_state}; - return gstate.global_partition->Combine(context, 
hcombine); + OperatorSinkCombineInput hcombine {*gstate.strategy_sink, *lstate.local_group, combine.interrupt_state}; + return gstate.sort_strategy->Combine(context, hcombine); } unique_ptr PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const { @@ -320,24 +343,17 @@ unique_ptr PhysicalWindow::GetGlobalSinkState(ClientContext &cl SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gsink = input.global_state.Cast(); - auto &global_partition = *gsink.global_partition; - auto &hashed_sink = *gsink.hashed_sink; - - OperatorSinkFinalizeInput hfinalize {hashed_sink, input.interrupt_state}; - auto result = global_partition.Finalize(client, hfinalize); - - // Did we get any data? - if (result != SinkFinalizeType::READY) { - return result; - } + auto &sort_strategy = *gsink.sort_strategy; + auto &strategy_sink = *gsink.strategy_sink; - return global_partition.MaterializeHashGroups(pipeline, event, *this, hfinalize); + OperatorSinkFinalizeInput hfinalize {strategy_sink, input.interrupt_state}; + return sort_strategy.Finalize(client, hfinalize); } ProgressData PhysicalWindow::GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, const ProgressData source_progress) const { auto &gsink = gstate.Cast(); - return gsink.global_partition->GetSinkProgress(context, *gsink.hashed_sink, source_progress); + return gsink.sort_strategy->GetSinkProgress(context, *gsink.strategy_sink, source_progress); } //===--------------------------------------------------------------------===// @@ -367,6 +383,8 @@ class WindowGlobalSourceState : public GlobalSourceState { ClientContext &client; //! All the sunk data WindowGlobalSinkState &gsink; + //! The hashed sort global source state for delayed sorting + unique_ptr hashed_source; //! The sorted hash groups vector window_hash_groups; //! 
The total number of blocks to process; @@ -404,20 +422,18 @@ class WindowGlobalSourceState : public GlobalSourceState { WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &client, WindowGlobalSinkState &gsink_p) : client(client), gsink(gsink_p), next_group(0), locals(0), started(0), finished(0), stopped(false), completed(0) { - - auto &global_partition = *gsink.global_partition; - auto hashed_source = global_partition.GetGlobalSourceState(client, *gsink.hashed_sink); - auto &hash_groups = global_partition.GetHashGroups(*hashed_source); + auto &sort_strategy = *gsink.sort_strategy; + hashed_source = sort_strategy.GetGlobalSourceState(client, *gsink.strategy_sink); + auto &hash_groups = sort_strategy.GetHashGroups(*hashed_source); window_hash_groups.resize(hash_groups.size()); for (idx_t group_idx = 0; group_idx < hash_groups.size(); ++group_idx) { - auto rows = std::move(hash_groups[group_idx]); - if (!rows) { + const auto block_count = hash_groups[group_idx].chunks; + if (!block_count) { continue; } - auto window_hash_group = make_uniq(gsink, rows, group_idx); - const auto block_count = window_hash_group->rows->ChunkCount(); + auto window_hash_group = make_uniq(gsink, hash_groups[group_idx], group_idx); window_hash_group->batch_base = total_blocks; total_blocks += block_count; @@ -438,7 +454,7 @@ void WindowGlobalSourceState::CreateTaskList() { if (!window_hash_group) { continue; } - partition_blocks.emplace_back(window_hash_group->rows->ChunkCount(), group_idx); + partition_blocks.emplace_back(window_hash_group->blocks, group_idx); } std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); @@ -453,8 +469,8 @@ void WindowGlobalSourceState::CreateTaskList() { // STANDARD_VECTOR_SIZE >> ValidityMask::BITS_PER_VALUE, but if STANDARD_VECTOR_SIZE is say 2, // we need to align the chunk count to the mask width. const auto aligned_scale = MaxValue(ValidityMask::BITS_PER_VALUE / STANDARD_VECTOR_SIZE, 1); - const auto aligned_count = (max_block.first + aligned_scale - 1) / aligned_scale; - const auto per_thread = aligned_scale * ((aligned_count + threads - 1) / threads); + const auto aligned_count = WindowHashGroup::BinValue(max_block.first, aligned_scale); + const auto per_thread = aligned_scale * WindowHashGroup::BinValue(aligned_count, threads); if (!per_thread) { throw InternalException("No blocks per thread! %ld threads, %ld groups, %ld blocks, %ld hash group", threads, partition_blocks.size(), max_block.first, max_block.second); @@ -465,23 +481,14 @@ void WindowGlobalSourceState::CreateTaskList() { } } -WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, HashGroupPtr &sorted, const idx_t hash_bin_p) - : gsink(gsink), count(0), blocks(0), rows(std::move(sorted)), stage(WindowGroupStage::MASK), hash_bin(hash_bin_p), - masked(0), sunk(0), finalized(0), completed(0), batch_base(0) { +WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p) + : gsink(gsink), count(chunk_row.count), blocks(chunk_row.chunks), stage(WindowGroupStage::SORT), + hash_bin(hash_bin_p), sorted(0), materialized(0), masked(0), sunk(0), finalized(0), completed(0), batch_base(0) { // There are three types of partitions: // 1. No partition (no sorting) // 2. One partition (sorting, but no hashing) // 3. Multiple partitions (sorting and hashing) - // How big is the partition? 
- auto &gpart = *gsink.global_partition; - layout.Initialize(gpart.payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - - if (rows) { - count = rows->Count(); - blocks = rows->ChunkCount(); - } - // Set up the collection for any fully materialised data const auto &shared = WindowSharedExpressions::GetSortedExpressions(gsink.shared.coll_shared); vector types; @@ -497,7 +504,7 @@ unique_ptr WindowHashGroup::GetScanner(const idx_t return nullptr; } - auto &scan_ids = gsink.global_partition->scan_ids; + auto &scan_ids = gsink.sort_strategy->scan_ids; return make_uniq(*rows, scan_ids, begin_idx); } @@ -542,7 +549,6 @@ void WindowHashGroup::ComputeMasks(const idx_t block_begin, const idx_t block_en // If the data is unsorted, then the chunk sizes may be < STANDARD_VECTOR_SIZE, // and the entry range may be empty. if (begin_entry >= end_entry) { - D_ASSERT(gsink.global_partition->sort_col_count == 0); return; } @@ -558,22 +564,23 @@ void WindowHashGroup::ComputeMasks(const idx_t block_begin, const idx_t block_en } // If we are not sorting, then only the partition boundaries are needed. - if (!gsink.global_partition->sort) { + const auto &wexpr = gsink.op.select_list[gsink.op.order_idx]->Cast(); + auto &partitions = wexpr.partitions; + if (partitions.empty() && wexpr.orders.empty()) { return; } // Set up the partition compare structs - auto &partitions = gsink.global_partition->partitions; const auto key_count = partitions.size(); // Set up the order data structures auto &collection = *rows; - auto &scan_cols = gsink.global_partition->sort_ids; + auto &scan_cols = gsink.sort_strategy->sort_ids; WindowCollectionChunkScanner scanner(collection, scan_cols, block_begin); unordered_map prefixes; for (auto &order_mask : order_masks) { - D_ASSERT(order_mask.first >= partitions.size()); - auto order_type = scanner.PrefixStructType(order_mask.first, partitions.size()); + D_ASSERT(order_mask.first >= key_count); + auto order_type = scanner.PrefixStructType(order_mask.first, key_count); vector types(2, order_type); auto &keys = prefixes[order_mask.first]; // We can't use InitializeEmpty here because it doesn't set up all of the STRUCT internals... @@ -664,6 +671,10 @@ class WindowLocalSourceState : public LocalSourceState { DataChunk output_chunk; protected: + //! Sort the partition + void Sort(ExecutionContext &context, InterruptState &interrupt); + //! Materialize the sorted run + void Materialize(ExecutionContext &context, InterruptState &interrupt); //! Compute a mask range void Mask(ExecutionContext &context, InterruptState &interrupt); //! 
Sink tuples into function global states @@ -689,12 +700,50 @@ class WindowLocalSourceState : public LocalSourceState { idx_t WindowHashGroup::InitTasks(idx_t per_thread_p) { per_thread = per_thread_p; - group_threads = (rows->ChunkCount() + per_thread - 1) / per_thread; + group_threads = BinValue(ChunkCount(), per_thread); thread_states.resize(GetThreadCount()); return GetTaskCount(); } +void WindowLocalSourceState::Sort(ExecutionContext &context, InterruptState &interrupt) { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::SORT); + + auto &gsink = gsource.gsink; + auto &sort_strategy = *gsink.sort_strategy; + OperatorSinkFinalizeInput finalize {*gsink.strategy_sink, interrupt}; + sort_strategy.SortColumnData(context, task_local.group_idx, finalize); + + // Mark this range as done + window_hash_group->sorted += (task->end_idx - task->begin_idx); + task->begin_idx = task->end_idx; +} + +void WindowLocalSourceState::Materialize(ExecutionContext &context, InterruptState &interrupt) { + D_ASSERT(task); + D_ASSERT(task->stage == WindowGroupStage::MATERIALIZE); + + auto unused = make_uniq(); + OperatorSourceInput source {*gsource.hashed_source, *unused, interrupt}; + auto &gsink = gsource.gsink; + auto &sort_strategy = *gsink.sort_strategy; + sort_strategy.MaterializeColumnData(context, task_local.group_idx, source); + + // Mark this range as done + window_hash_group->materialized += (task->end_idx - task->begin_idx); + task->begin_idx = task->end_idx; + + // There is no good place to read the column data, + // and if we do it twice we can split the results. + if (window_hash_group->materialized >= window_hash_group->blocks) { + lock_guard prepare_guard(window_hash_group->lock); + if (!window_hash_group->rows) { + window_hash_group->rows = sort_strategy.GetColumnData(task_local.group_idx, source); + } + } +} + void WindowLocalSourceState::Mask(ExecutionContext &context, InterruptState &interrupt) { D_ASSERT(task); D_ASSERT(task->stage == WindowGroupStage::MASK); @@ -901,11 +950,6 @@ void WindowGlobalSourceState::FinishTask(TaskPtr task) { bool WindowLocalSourceState::TryAssignTask() { D_ASSERT(TaskFinished()); - if (task && task->stage == WindowGroupStage::GETDATA) { - // If this state completed the last block in the previous iteration, - // release our local state memory. - ReleaseLocalStates(); - } // Because downstream operators may be using our internal buffers, // we can't "finish" a task until we are about to get the next one. 
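To make the new task accounting concrete (the thread count below is only an example):

    // Each thread walks SORT, MATERIALIZE, MASK, SINK, FINALIZE and GETDATA for its block range,
    // so with GetThreadCount() == 4 a hash group schedules GetTaskCount() == 4 * 6 == 24 tasks.
    // TryPrepareNextStage() advances the whole group only once the per-stage counter
    // (sorted, materialized, masked, ...) has reached the group's block count.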
@@ -921,6 +965,14 @@ void WindowLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &r // Process the new state switch (task->stage) { + case WindowGroupStage::SORT: + Sort(context, interrupt); + D_ASSERT(TaskFinished()); + break; + case WindowGroupStage::MATERIALIZE: + Materialize(context, interrupt); + D_ASSERT(TaskFinished()); + break; case WindowGroupStage::MASK: Mask(context, interrupt); D_ASSERT(TaskFinished()); @@ -1056,8 +1108,8 @@ OperatorPartitionData PhysicalWindow::GetPartitionData(ExecutionContext &context return OperatorPartitionData(lstate.batch_index); } -SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &source) const { +SourceResultType PhysicalWindow::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &source) const { auto &gsource = source.global_state.Cast(); auto &lsource = source.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp index 1e186fa97..179c5bcbf 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp @@ -26,6 +26,10 @@ BaseScanner::BaseScanner(shared_ptr buffer_manager_p, shared_p } } +void BaseScanner::Print() const { + state_machine->Print(); +} + string BaseScanner::RemoveSeparator(const char *value_ptr, const idx_t size, char thousands_separator) { string result; result.reserve(size); diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index 5ed14a992..bd6e9b546 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -22,7 +22,7 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_hander_p, CSVIterator &iterator_p, bool store_line_size_p, shared_ptr csv_file_scan_p, idx_t &lines_read_p, bool sniffing_p, - string path_p, idx_t scan_id) + const string &path_p, idx_t scan_id, bool &used_unstrictness) : ScannerResult(states, state_machine, result_size_p), number_of_columns(NumericCast(state_machine.dialect_options.num_cols)), null_padding(state_machine.options.null_padding), ignore_errors(state_machine.options.ignore_errors.GetValue()), @@ -30,8 +30,8 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m ? 
0 : state_machine.dialect_options.state_machine_options.delimiter.GetValue().size() - 1), error_handler(error_hander_p), iterator(iterator_p), store_line_size(store_line_size_p), - csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), - current_errors(scan_id, state_machine.options.IgnoreErrors()), sniffing(sniffing_p), path(std::move(path_p)) { + csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), used_unstrictness(used_unstrictness), + current_errors(scan_id, state_machine.options.IgnoreErrors()), sniffing(sniffing_p), path(path_p) { // Vector information D_ASSERT(number_of_columns > 0); if (!buffer_handle) { @@ -154,23 +154,26 @@ inline bool IsValueNull(const char *null_str_ptr, const char *value_ptr, const i } bool StringValueResult::HandleTooManyColumnsError(const char *value_ptr, const idx_t size) { - if (cur_col_id >= number_of_columns && state_machine.state_machine_options.strict_mode.GetValue()) { - bool error = true; - if (cur_col_id == number_of_columns && ((quoted && state_machine.options.allow_quoted_nulls) || !quoted)) { - // we make an exception if the first over-value is null - bool is_value_null = false; - for (idx_t i = 0; i < null_str_count; i++) { - is_value_null = is_value_null || IsValueNull(null_str_ptr[i], value_ptr, size); + if (cur_col_id >= number_of_columns) { + if (state_machine.state_machine_options.strict_mode.GetValue()) { + bool error = true; + if (cur_col_id == number_of_columns && ((quoted && state_machine.options.allow_quoted_nulls) || !quoted)) { + // we make an exception if the first over-value is null + bool is_value_null = false; + for (idx_t i = 0; i < null_str_count; i++) { + is_value_null = is_value_null || IsValueNull(null_str_ptr[i], value_ptr, size); + } + error = !is_value_null; } - error = !is_value_null; - } - if (error) { - // We error pointing to the current value error. - current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); - cur_col_id++; + if (error) { + // We error pointing to the current value error. 
+ current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); + cur_col_id++; + } + // We had an error + return true; } - // We had an error - return true; + used_unstrictness = true; } return false; } @@ -231,6 +234,7 @@ void StringValueResult::AddValueToVector(const char *value_ptr, idx_t size, bool } if (cur_col_id >= number_of_columns) { if (!state_machine.state_machine_options.strict_mode.GetValue()) { + used_unstrictness = true; return; } bool error = true; @@ -549,6 +553,7 @@ void StringValueResult::AddPossiblyEscapedValue(StringValueResult &result, const } if (result.cur_col_id >= result.number_of_columns && !result.state_machine.state_machine_options.strict_mode.GetValue()) { + result.used_unstrictness = true; return; } if (!result.HandleTooManyColumnsError(value_ptr, length)) { @@ -639,7 +644,6 @@ void StringValueResult::AddValue(StringValueResult &result, const idx_t buffer_p } void StringValueResult::HandleUnicodeError(idx_t col_idx, LinePosition &error_position) { - bool first_nl = false; auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); @@ -980,7 +984,7 @@ StringValueScanner::StringValueScanner(idx_t scanner_idx_p, const shared_ptrcontext), result_size, iterator.pos.buffer_pos, *error_handler, iterator, buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, - buffer_manager->GetFilePath(), scanner_idx_p), + buffer_manager->GetFilePath(), scanner_idx_p, used_unstrictness), start_pos(0) { if (scanner_idx == 0 && csv_file_scan) { lines_read += csv_file_scan->skipped_rows; @@ -997,7 +1001,7 @@ StringValueScanner::StringValueScanner(const shared_ptr &buffe result(states, *state_machine, cur_buffer_handle, Allocator::DefaultAllocator(), result_size, iterator.pos.buffer_pos, *error_handler, iterator, buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, - buffer_manager->GetFilePath(), 0), + buffer_manager->GetFilePath(), 0, used_unstrictness), start_pos(0) { if (scanner_idx == 0 && csv_file_scan) { lines_read += csv_file_scan->skipped_rows; @@ -1939,14 +1943,17 @@ void StringValueScanner::FinalizeChunkProcess() { if (result.current_errors.HandleErrors(result)) { result.number_of_rows++; } - if (states.IsQuotedCurrent() && !found_error && - state_machine->dialect_options.state_machine_options.strict_mode.GetValue()) { - type = UNTERMINATED_QUOTES; - // If we finish the execution of a buffer, and we end in a quoted state, it means we have unterminated - // quotes - result.current_errors.Insert(type, result.cur_col_id, result.chunk_col_id, result.last_position); - if (result.current_errors.HandleErrors(result)) { - result.number_of_rows++; + if (states.IsQuotedCurrent() && !found_error) { + if (state_machine->dialect_options.state_machine_options.strict_mode.GetValue()) { + type = UNTERMINATED_QUOTES; + // If we finish the execution of a buffer, and we end in a quoted state, it means we have unterminated + // quotes + result.current_errors.Insert(type, result.cur_col_id, result.chunk_col_id, result.last_position); + if (result.current_errors.HandleErrors(result)) { + result.number_of_rows++; + } + } else { + used_unstrictness = true; } } if (!iterator.done) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index bcaed8e5f..9cf087871 
100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -14,7 +14,7 @@ CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, const MultiFileOptions &file auto &logical_type = format_template.first; best_format_candidates[logical_type].clear(); } - // Initialize max columns found to either 0 or however many were set + // Initialize max columns found to either 0, or however many were set max_columns_found = set_columns.Size(); error_handler = make_shared_ptr(options.ignore_errors.GetValue()); detection_error_handler = make_shared_ptr(true); @@ -193,7 +193,8 @@ SnifferResult CSVSniffer::SniffCSV(const bool force_match) { buffer_manager->ResetBufferManager(); } buffer_manager->sniffing = false; - if (best_candidate->error_handler->AnyErrors() && !options.ignore_errors.GetValue()) { + if (best_candidate->error_handler->AnyErrors() && !options.ignore_errors.GetValue() && + best_candidate->state_machine->dialect_options.state_machine_options.strict_mode.GetValue()) { best_candidate->error_handler->ErrorIfTypeExists(MAXIMUM_LINE_SIZE); } D_ASSERT(best_sql_types_candidates_per_column_idx.size() == names.size()); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp index fc8dc9385..880a06149 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -146,7 +146,6 @@ void CSVSniffer::GenerateStateMachineSearchSpace(vector(info_sql_types_candidates.size()) > static_cast(max_columns_found) * 0.7; const idx_t number_of_errors = candidate->error_handler->GetSize(); - if (!best_candidate || (varchar_cols(info_sql_types_candidates.size())>( - static_cast(max_columns_found) * 0.7) && - (!options.ignore_errors.GetValue() || number_of_errors < min_errors))) { + const bool better_strictness = best_candidate_is_strict ? !candidate->used_unstrictness : true; + const bool acceptable_candidate = has_less_varchar_cols && acceptable_best_num_cols && better_strictness; + // If we escaped an unquoted character when strict is false. 
+ if (!best_candidate || + (acceptable_candidate && (!options.ignore_errors.GetValue() || number_of_errors < min_errors))) { min_errors = number_of_errors; best_header_row.clear(); // we have a new best_options candidate best_candidate = std::move(candidate); min_varchar_cols = varchar_cols; + best_candidate_is_strict = !best_candidate->used_unstrictness; best_sql_types_candidates_per_column_idx = info_sql_types_candidates; for (auto &format_candidate : format_candidates) { best_format_candidates[format_candidate.first] = format_candidate.second.format; diff --git a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp index a8ac5a53f..e78f8f99b 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp @@ -495,7 +495,6 @@ const StateMachine &CSVStateMachineCache::Get(const CSVStateMachineOptions &stat } CSVStateMachineCache &CSVStateMachineCache::Get(ClientContext &context) { - auto &cache = ObjectCache::GetObjectCache(context); return *cache.GetOrCreate(CSVStateMachineCache::ObjectType()); } diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index 8a830c8c9..e293c7337 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -13,7 +13,6 @@ CSVFileScan::CSVFileScan(ClientContext &context, const OpenFileInfo &file_p, CSV : BaseFileReader(file_p), buffer_manager(std::move(buffer_manager_p)), error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(std::move(options_p)) { - // Initialize Buffer Manager if (!buffer_manager) { buffer_manager = make_shared_ptr(context, options, file, per_file_single_threaded); diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp index 260ac35d4..843209eb7 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp @@ -365,13 +365,15 @@ bool CSVFileScan::TryInitializeScan(ClientContext &context, GlobalTableFunctionS return true; } -void CSVFileScan::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &chunk) { +AsyncResult CSVFileScan::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) { auto &lstate = local_state.Cast(); if (lstate.csv_reader->FinishedIterator()) { - return; + return AsyncResult(SourceResultType::FINISHED); } lstate.csv_reader->Flush(chunk); + return chunk.size() == 0 ? 
AsyncResult(SourceResultType::FINISHED)
+	                          : AsyncResult(SourceResultType::HAVE_MORE_OUTPUT);
 }

 void CSVFileScan::FinishFile(ClientContext &context, GlobalTableFunctionState &global_state) {
diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
index 7fd64d889..07558c6bc 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
@@ -275,7 +275,7 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, LinesPerBoundary
 CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx_p, string csv_row_p,
                    LinesPerBoundary error_info_p, idx_t row_byte_position, optional_idx byte_position_p,
-                   const CSVReaderOptions &reader_options, const string &fixes, const string &current_path)
+                   const CSVReaderOptions &reader_options, const string &fixes, const String &current_path)
     : error_message(std::move(error_message_p)), type(type_p), column_idx(column_idx_p), csv_row(std::move(csv_row_p)),
       error_info(error_info_p), row_byte_position(row_byte_position), byte_position(byte_position_p) {
	// What were the options
@@ -319,7 +319,7 @@ void CSVError::RemoveNewLine(string &error) {
 CSVError CSVError::CastError(const CSVReaderOptions &options, const string &column_name, string &cast_error,
                              idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position,
-                             optional_idx byte_position, LogicalTypeId type, const string &current_path) {
+                             optional_idx byte_position, LogicalTypeId type, const String &current_path) {
	std::ostringstream error;
	// Which column
	error << "Error when converting column \"" << column_name << "\". ";
@@ -350,7 +350,7 @@ CSVError CSVError::CastError(const CSVReaderOptions &options, const string &colu
 }

 CSVError CSVError::LineSizeError(const CSVReaderOptions &options, LinesPerBoundary error_info, string &csv_row,
-                                 idx_t byte_position, const string &current_path) {
+                                 idx_t byte_position, const String &current_path) {
	std::ostringstream error;
	error << "Maximum line size of " << options.maximum_line_size.GetValue() << " bytes exceeded. ";
	error << "Actual Size:" << csv_row.size() << " bytes." << '\n';
@@ -365,7 +365,7 @@ CSVError CSVError::LineSizeError(const CSVReaderOptions &options, LinesPerBounda
 CSVError CSVError::InvalidState(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info,
                                 string &csv_row, idx_t row_byte_position, optional_idx byte_position,
-                                const string &current_path) {
+                                const String &current_path) {
	std::ostringstream error;
	error << "The CSV Parser state machine reached an invalid state.\nThis can happen when is not possible to parse "
	         "your CSV File with the given options, or the CSV File is not RFC 4180 compliant ";
@@ -521,7 +521,7 @@ CSVError CSVError::SniffingError(const CSVReaderOptions &options, const string &
 }

 CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info,
-                                   const string &current_path) {
+                                   const String &current_path) {
	std::ostringstream error;
	error << " The parallel scanner does not support null_padding in conjunction with quoted new lines.
Please " "disable the parallel csv reader with parallel=false" @@ -533,7 +533,7 @@ CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoun CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path) { + optional_idx byte_position, const String ¤t_path) { std::ostringstream error; error << "Value with unterminated quote found." << '\n'; std::ostringstream how_to_fix_it; @@ -551,7 +551,7 @@ CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path) { + optional_idx byte_position, const String ¤t_path) { std::ostringstream error; // We don't have a fix for this std::ostringstream how_to_fix_it; @@ -581,7 +581,7 @@ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, i CSVError CSVError::InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path) { + const String ¤t_path) { std::ostringstream error; // How many columns were expected and how many were found error << "Invalid unicode (byte sequence mismatch) detected. This file is not " << options.encoding << " encoded." diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index 5801e99b0..75a680b3b 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/common/multi_file/multi_file_reader.hpp" #include "duckdb/common/set.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { @@ -465,7 +466,7 @@ bool CSVReaderOptions::WasTypeManuallySet(idx_t i) const { return was_type_manually_set[i]; } -string CSVReaderOptions::ToString(const string ¤t_file_path) const { +string CSVReaderOptions::ToString(const String ¤t_file_path) const { auto &delimiter = dialect_options.state_machine_options.delimiter; auto "e = dialect_options.state_machine_options.quote; auto &escape = dialect_options.state_machine_options.escape; @@ -475,7 +476,7 @@ string CSVReaderOptions::ToString(const string ¤t_file_path) const { auto &skip_rows = dialect_options.skip_rows; auto &header = dialect_options.header; - string error = " file = " + current_file_path + "\n "; + string error = " file = " + current_file_path.ToStdString() + "\n "; // Let's first print options that can either be set by the user or by the sniffer // delimiter error += FormatOptionLine("delimiter", delimiter); diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp index 2921a0e83..889667e82 100644 --- a/src/duckdb/src/execution/operator/filter/physical_filter.cpp +++ b/src/duckdb/src/execution/operator/filter/physical_filter.cpp @@ -7,7 +7,6 @@ namespace duckdb { PhysicalFilter::PhysicalFilter(PhysicalPlan &physical_plan, vector types, vector> select_list, idx_t estimated_cardinality) : CachingPhysicalOperator(physical_plan, PhysicalOperatorType::FILTER, 
std::move(types), estimated_cardinality) { - D_ASSERT(!select_list.empty()); if (select_list.size() == 1) { expression = std::move(select_list[0]); diff --git a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp index c36e891f2..e79a4d044 100644 --- a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp @@ -47,7 +47,7 @@ unique_ptr PhysicalBatchCollector::GetGlobalSinkState(ClientCon return make_uniq(context, *this); } -unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); D_ASSERT(gstate.result); return std::move(gstate.result); diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp index 404d14343..9e3caab6c 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp @@ -53,7 +53,6 @@ SinkResultType PhysicalBufferedBatchCollector::Sink(ExecutionContext &context, D SinkNextBatchType PhysicalBufferedBatchCollector::NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &input) const { - auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); @@ -94,11 +93,11 @@ unique_ptr PhysicalBufferedBatchCollector::GetLocalSinkState(Exe unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } -unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); auto cc = gstate.context.lock(); auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp index 7795230dc..8ee1e1617 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp @@ -48,7 +48,7 @@ SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &conte unique_ptr PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); + state->buffered_data = make_shared_ptr(context); return std::move(state); } @@ -57,7 +57,7 @@ unique_ptr PhysicalBufferedCollector::GetLocalSinkState(Executio return std::move(state); } -unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); lock_guard l(gstate.glock); // FIXME: maybe we want to check if the execution was successful before creating the StreamQueryResult ? 
diff --git a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp b/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp index ac06b3950..b2cd0eca3 100644 --- a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp @@ -5,8 +5,8 @@ namespace duckdb { -SourceResultType PhysicalCreateSecret::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateSecret::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &client = context.client; auto &secret_manager = SecretManager::Get(client); diff --git a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp index 8d2c12811..1404ea0e9 100644 --- a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp @@ -32,8 +32,8 @@ unique_ptr PhysicalExplainAnalyze::GetGlobalSinkState(ClientCon //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalExplainAnalyze::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalExplainAnalyze::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); chunk.SetValue(0, 0, Value("analyzed_plan")); diff --git a/src/duckdb/src/execution/operator/helper/physical_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_limit.cpp index 5a4339c63..d36bd60d6 100644 --- a/src/duckdb/src/execution/operator/helper/physical_limit.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_limit.cpp @@ -8,6 +8,8 @@ namespace duckdb { +constexpr const idx_t PhysicalLimit::MAX_LIMIT_VALUE; + PhysicalLimit::PhysicalLimit(PhysicalPlan &physical_plan, vector types, BoundLimitNode limit_val_p, BoundLimitNode offset_val_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), @@ -19,7 +21,8 @@ PhysicalLimit::PhysicalLimit(PhysicalPlan &physical_plan, vector ty //===--------------------------------------------------------------------===// class LimitGlobalState : public GlobalSinkState { public: - explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) : data(context, op.types, true) { + explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) + : data(context, op.types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { limit = 0; offset = 0; } @@ -33,7 +36,7 @@ class LimitGlobalState : public GlobalSinkState { class LimitLocalState : public LocalSinkState { public: explicit LimitLocalState(ClientContext &context, const PhysicalLimit &op) - : current_offset(0), data(context, op.types, true) { + : current_offset(0), data(context, op.types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { PhysicalLimit::SetInitialLimits(op.limit_val, op.offset_val, limit, offset); } @@ -108,7 +111,6 @@ bool PhysicalLimit::ComputeOffset(ExecutionContext &context, DataChunk &input, o } SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - D_ASSERT(chunk.size() > 0); auto &state = 
input.local_state.Cast(); auto &limit = state.limit; @@ -165,7 +167,8 @@ unique_ptr PhysicalLimit::GetGlobalSourceState(ClientContext return make_uniq(); } -SourceResultType PhysicalLimit::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalLimit::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); auto &state = input.global_state.Cast(); while (state.current_offset < gstate.limit + gstate.offset) { diff --git a/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp b/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp index dfb2eef26..20db72032 100644 --- a/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp @@ -118,8 +118,8 @@ unique_ptr PhysicalLimitPercent::GetGlobalSourceState(ClientC return make_uniq(*this); } -SourceResultType PhysicalLimitPercent::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalLimitPercent::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); auto &state = input.global_state.Cast(); auto &percent_limit = gstate.limit_percent; diff --git a/src/duckdb/src/execution/operator/helper/physical_load.cpp b/src/duckdb/src/execution/operator/helper/physical_load.cpp index 5f0e7a027..d6bcda546 100644 --- a/src/duckdb/src/execution/operator/helper/physical_load.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_load.cpp @@ -25,7 +25,8 @@ static void InstallFromRepository(ClientContext &context, const LoadInfo &info) ExtensionHelper::InstallExtension(context, info.filename, options); } -SourceResultType PhysicalLoad::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalLoad::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { if (info->load_type == LoadType::INSTALL || info->load_type == LoadType::FORCE_INSTALL) { if (info->repository.empty()) { ExtensionInstallOptions options; diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp index a70b914ce..a42e7a4a8 100644 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -2,6 +2,7 @@ #include "duckdb/main/materialized_query_result.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -10,6 +11,19 @@ PhysicalMaterializedCollector::PhysicalMaterializedCollector(PhysicalPlan &physi : PhysicalResultCollector(physical_plan, data), parallel(parallel) { } +class MaterializedCollectorGlobalState : public GlobalSinkState { +public: + mutex glock; + unique_ptr collection; + shared_ptr context; +}; + +class MaterializedCollectorLocalState : public LocalSinkState { +public: + unique_ptr collection; + ColumnDataAppendState append_state; +}; + SinkResultType PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &lstate = input.local_state.Cast(); @@ -43,15 +57,15 @@ unique_ptr PhysicalMaterializedCollector::GetGlobalSinkState(Cl unique_ptr 
PhysicalMaterializedCollector::GetLocalSinkState(ExecutionContext &context) const { auto state = make_uniq(); - state->collection = make_uniq(Allocator::DefaultAllocator(), types); + state->collection = CreateCollection(context.client); state->collection->InitializeAppend(state->append_state); return std::move(state); } -unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) { +unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); if (!gstate.collection) { - gstate.collection = make_uniq(Allocator::DefaultAllocator(), types); + gstate.collection = CreateCollection(*gstate.context); } auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), gstate.context->GetClientProperties()); diff --git a/src/duckdb/src/execution/operator/helper/physical_pragma.cpp b/src/duckdb/src/execution/operator/helper/physical_pragma.cpp index 34a5cb6fb..df48d7f8d 100644 --- a/src/duckdb/src/execution/operator/helper/physical_pragma.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_pragma.cpp @@ -2,8 +2,8 @@ namespace duckdb { -SourceResultType PhysicalPragma::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalPragma::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &client = context.client; FunctionParameters parameters {info->parameters, info->named_parameters}; info->function.function(client, parameters); diff --git a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp index 784d6ada4..7603b6e87 100644 --- a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp @@ -3,12 +3,12 @@ namespace duckdb { -SourceResultType PhysicalPrepare::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalPrepare::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &client = context.client; // store the prepared statement in the context - ClientData::Get(client).prepared_statements[name] = prepared; + ClientData::Get(client).prepared_statements[name.ToStdString()] = prepared; return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp index 7b89c7a14..193579123 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp @@ -78,8 +78,8 @@ SinkFinalizeType PhysicalReservoirSample::Finalize(Pipeline &pipeline, Event &ev //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalReservoirSample::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &sink = this->sink_state->Cast(); lock_guard glock(sink.lock); if (!sink.sample) { diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp index 1f5baf75d..4402d6e3e 
100644 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -19,7 +19,8 @@ void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig & } } -SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalReset::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { if (scope == SetScope::VARIABLE) { auto &client_config = ClientConfig::GetConfig(context.client); client_config.ResetUserVariable(name); @@ -36,8 +37,7 @@ SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &ch auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { - throw InvalidInputException("Extension parameter %s was not found after autoloading", - name.ToStdString()); + throw InvalidInputException("Extension parameter %s was not found after autoloading", name); } } ResetExtensionVariable(context, config, entry->second); diff --git a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp index d78bf225b..df95ce707 100644 --- a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp @@ -10,13 +10,15 @@ #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/parallel/pipeline.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { PhysicalResultCollector::PhysicalResultCollector(PhysicalPlan &physical_plan, PreparedStatementData &data) : PhysicalOperator(physical_plan, PhysicalOperatorType::RESULT_COLLECTOR, {LogicalType::BOOLEAN}, 0), - statement_type(data.statement_type), properties(data.properties), plan(data.physical_plan->Root()), - names(data.names) { + statement_type(data.statement_type), properties(data.properties), memory_type(data.memory_type), + plan(data.physical_plan->Root()), names(data.names) { types = data.types; } @@ -26,7 +28,7 @@ PhysicalOperator &PhysicalResultCollector::GetResultCollector(ClientContext &con if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, root)) { // Not an order-preserving plan: use the parallel materialized collector. - if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data, true); } return physical_plan.Make(data, true); @@ -34,14 +36,14 @@ PhysicalOperator &PhysicalResultCollector::GetResultCollector(ClientContext &con if (!PhysicalPlanGenerator::UseBatchIndex(context, root)) { // Order-preserving plan, and we cannot use the batch index: use single-threaded result collector. - if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data, false); } return physical_plan.Make(data, false); } // Order-preserving plan, and we can use the batch index: use a batch collector. 
- if (data.is_streaming) { + if (data.output_type == QueryResultOutputType::ALLOW_STREAMING) { return physical_plan.Make(data); } return physical_plan.Make(data); @@ -66,4 +68,18 @@ void PhysicalResultCollector::BuildPipelines(Pipeline ¤t, MetaPipeline &me child_meta_pipeline.Build(plan); } +unique_ptr PhysicalResultCollector::CreateCollection(ClientContext &context) const { + switch (memory_type) { + case QueryResultMemoryType::IN_MEMORY: + return make_uniq(Allocator::DefaultAllocator(), types); + case QueryResultMemoryType::BUFFER_MANAGED: + // Use the DatabaseInstance BufferManager because the query result can outlive the ClientContext + return make_uniq(BufferManager::GetBufferManager(*context.db), types, + ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES); + default: + throw NotImplementedException("PhysicalResultCollector::CreateCollection for %s", + EnumUtil::ToString(memory_type)); + } +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp index e8362ad9c..82f505113 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp @@ -6,17 +6,17 @@ namespace duckdb { -void PhysicalSet::SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value) { +void PhysicalSet::SetGenericVariable(ClientContext &context, const String &name, SetScope scope, Value target_value) { if (scope == SetScope::GLOBAL) { auto &config = DBConfig::GetConfig(context); config.SetOption(name, std::move(target_value)); } else { auto &client_config = ClientConfig::GetConfig(context); - client_config.set_variables[name] = std::move(target_value); + client_config.set_variables[name.ToStdString()] = std::move(target_value); } } -void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, +void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const String &name, SetScope scope, const Value &value) { auto &target_type = extension_option.type; Value target_value = value.CastAs(context, target_type); @@ -29,17 +29,18 @@ void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption & SetGenericVariable(context, name, scope, std::move(target_value)); } -SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalSet::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &config = DBConfig::GetConfig(context.client); // check if we are allowed to change the configuration option config.CheckLock(name); auto option = DBConfig::GetOptionByName(name); if (!option) { // check if this is an extra extension variable - auto entry = config.extension_parameters.find(name); + auto entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); + entry = config.extension_parameters.find(name.ToStdString()); if (entry == config.extension_parameters.end()) { throw InvalidInputException("Extension parameter %s was not found after autoloading", name); } diff --git a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp 
b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp index 430e7055e..7813ef0e2 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp @@ -1,16 +1,17 @@ #include "duckdb/execution/operator/helper/physical_set_variable.hpp" #include "duckdb/main/client_config.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { -PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, string name_p, idx_t estimated_cardinality) +PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, const string &name_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET_VARIABLE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(std::move(name_p)) { + name(physical_plan.ArenaRef().MakeString(name_p)) { } -SourceResultType PhysicalSetVariable::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalSetVariable::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp index 2ad12b09a..8e34f10b7 100644 --- a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp @@ -11,8 +11,8 @@ namespace duckdb { -SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalTransaction::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &client = context.client; auto type = info->type; diff --git a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp b/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp index d29632d12..30e1b28bd 100644 --- a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp @@ -3,8 +3,8 @@ namespace duckdb { -SourceResultType PhysicalUpdateExtensions::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalUpdateExtensions::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &data = input.global_state.Cast(); if (data.offset >= data.update_result_entries.size()) { diff --git a/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp b/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp index e46bc54fd..d90e4964c 100644 --- a/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp @@ -105,8 +105,8 @@ SinkFinalizeType PhysicalVacuum::Finalize(Pipeline &pipeline, Event &event, Clie return SinkFinalizeType::READY; } -SourceResultType PhysicalVacuum::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalVacuum::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { // NOP return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp index 
a48eaee4f..a92a2feae 100644 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -125,14 +125,14 @@ bool PerfectHashJoinExecutor::CanDoPerfectHashJoin(const PhysicalHashJoin &op, c //===--------------------------------------------------------------------===// bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { // First, allocate memory for each build column - auto build_size = perfect_join_statistics.build_range + 1; + const auto build_size = perfect_join_statistics.build_range + 1; for (const auto &type : join.rhs_output_columns.col_types) { - perfect_hash_table.emplace_back(type, build_size); + perfect_hash_table.emplace_back(DictionaryVector::CreateReusableDictionary(type, build_size)); } // and for duplicate_checking - bitmap_build_idx = make_unsafe_uniq_array_uninitialized(build_size); - memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false + bitmap_build_idx.Initialize(build_size); + bitmap_build_idx.SetAllInvalid(build_size); // Now fill columns with build data return FullScanHashTable(key_type); @@ -143,20 +143,8 @@ bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { // TODO: In a parallel finalize: One should exclusively lock and each thread should do one part of the code below. Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples - - idx_t key_count = 0; - if (data_collection.ChunkCount() > 0) { - JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); - - // Go through all the blocks and fill the keys addresses - key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); - } - - // Scan the build keys in the hash table - Vector build_vector(key_type, key_count); - data_collection.Gather(tuples_addresses, *FlatVector::IncrementalSelectionVector(), key_count, 0, build_vector, - *FlatVector::IncrementalSelectionVector(), nullptr); + Vector build_vector(key_type, ht.Count()); + auto key_count = ht.ScanKeyColumn(tuples_addresses, build_vector, 0); // Now fill the selection vector using the build keys and create a sequential vector // TODO: add check for fast pass when probe is part of build domain @@ -168,22 +156,25 @@ bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { if (!success) { return false; } - if (unique_keys == perfect_join_statistics.build_range + 1 && !ht.has_null) { + + const auto build_size = perfect_join_statistics.build_range + 1; + if (unique_keys == build_size && !ht.has_null) { perfect_join_statistics.is_build_dense = true; + bitmap_build_idx.Reset(build_size); // All valid } key_count = unique_keys; // do not consider keys out of the range // Full scan the remaining build columns and fill the perfect hash table - const auto build_size = perfect_join_statistics.build_range + 1; + for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { - auto &vector = perfect_hash_table[i]; + auto &vector = perfect_hash_table[i]->data; const auto output_col_idx = ht.output_columns[i]; D_ASSERT(vector.GetType() == ht.layout_ptr->GetTypes()[output_col_idx]); - if (build_size > STANDARD_VECTOR_SIZE) { - auto &col_mask = FlatVector::Validity(vector); - col_mask.Initialize(build_size); - } + auto &col_mask = FlatVector::Validity(vector); + col_mask.Reset(build_size); data_collection.Gather(tuples_addresses, sel_tuples, key_count, output_col_idx, vector, 
sel_build, nullptr); + // This ensures the empty entries are set to NULL, so that the emitted dictionary vectors make sense + col_mask.Combine(bitmap_build_idx, build_size); } return true; @@ -227,19 +218,19 @@ bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); UnifiedVectorFormat vector_data; source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); + const auto data = vector_data.GetData(); // generate the selection vector for (idx_t i = 0, sel_idx = 0; i < count; ++i) { auto data_idx = vector_data.sel->get_index(i); auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx position sel_vec.set_index(sel_idx, idx); - if (bitmap_build_idx[idx]) { + if (bitmap_build_idx.RowIsValidUnsafe(idx)) { return false; } else { - bitmap_build_idx[idx] = true; + bitmap_build_idx.SetValidUnsafe(idx); unique_keys++; } seq_sel_vec.set_index(sel_idx++, i); @@ -302,9 +293,7 @@ OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionConte for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { auto &result_vector = result.data[lhs_output_columns.ColumnCount() + i]; D_ASSERT(result_vector.GetType() == ht.layout_ptr->GetTypes()[ht.output_columns[i]]); - auto &build_vec = perfect_hash_table[i]; - result_vector.Reference(build_vec); - result_vector.Slice(state.build_sel_vec, probe_sel_count); + result_vector.Dictionary(perfect_hash_table[i], state.build_sel_vec); } return OperatorResultType::NEED_MORE_INPUT; } @@ -367,9 +356,9 @@ void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx + // position check for matches in the build + if (bitmap_build_idx.RowIsValid(idx)) { build_sel_vec.set_index(sel_idx, idx); probe_sel_vec.set_index(sel_idx++, i); probe_sel_count++; @@ -386,9 +375,9 @@ void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, auto input_value = data[data_idx]; // add index to selection vector if value in the range if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { + auto idx = UnsafeNumericCast(input_value - min_value); // subtract min value to get the idx + // position check for matches in the build + if (bitmap_build_idx.RowIsValid(idx)) { build_sel_vec.set_index(sel_idx, idx); probe_sel_vec.set_index(sel_idx++, i); probe_sel_count++; diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index 719781992..ab5a450f2 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -1,14 
+1,15 @@ #include "duckdb/execution/operator/join/physical_asof_join.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/sorting/sort_strategy.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/function/create_sort_key.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" namespace duckdb { @@ -16,9 +17,9 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso PhysicalOperator &right) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, op.estimated_cardinality), - comparison_type(ExpressionType::INVALID), predicate(std::move(op.predicate)) { - + comparison_type(ExpressionType::INVALID) { // Convert the conditions partitions and sorts + D_ASSERT(!op.predicate.get()); for (auto &cond : conditions) { D_ASSERT(cond.left->return_type == cond.right->return_type); join_key_types.push_back(cond.left->return_type); @@ -74,51 +75,44 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso //===--------------------------------------------------------------------===// class AsOfGlobalSinkState : public GlobalSinkState { public: - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) - : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1].get().GetTypes(), {}, - op.estimated_cardinality), - is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { - } - - idx_t Count() const { - return rhs_sink.count; + using SortStrategyPtr = unique_ptr; + using SortStrategySinkPtr = unique_ptr; + using PartitionMarkers = vector; + + AsOfGlobalSinkState(ClientContext &client, const PhysicalAsOfJoin &op) { + // Set up partitions for both sides + sort_strategies.reserve(2); + strategy_sinks.reserve(2); + const vector> partitions_stats; + auto &lhs = op.children[0].get(); + auto sort = SortStrategy::Factory(client, op.lhs_partitions, op.lhs_orders, lhs.GetTypes(), partitions_stats, + lhs.estimated_cardinality, true); + strategy_sinks.emplace_back(sort->GetGlobalSinkState(client)); + sort_strategies.emplace_back(std::move(sort)); + + auto &rhs = op.children[1].get(); + sort = SortStrategy::Factory(client, op.rhs_partitions, op.rhs_orders, rhs.GetTypes(), partitions_stats, + rhs.estimated_cardinality, true); + strategy_sinks.emplace_back(sort->GetGlobalSinkState(client)); + sort_strategies.emplace_back(std::move(sort)); } - PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { - lock_guard guard(lock); - lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); - return lhs_buffers.back().get(); - } - - PartitionGlobalSinkState rhs_sink; - - // One per partition - const bool is_outer; - vector right_outers; - bool has_null; - - // Left side buffering - unique_ptr lhs_sink; - - mutex lock; - vector> lhs_buffers; + //! The child that is being materialised (right/1 then left/0) + size_t child = 1; + //! The child's partitioning description + vector sort_strategies; + //! 
The child's partitioning buffer + vector strategy_sinks; }; class AsOfLocalSinkState : public LocalSinkState { public: - explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : local_partition(context, gstate_p) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); + AsOfLocalSinkState(ExecutionContext &context, AsOfGlobalSinkState &gsink) { + auto &sort_strategy = *gsink.sort_strategies[gsink.child]; + local_partition = sort_strategy.GetLocalSinkState(context); } - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; + unique_ptr local_partition; }; unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { @@ -126,411 +120,990 @@ unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext & } unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink.rhs_sink); + return make_uniq(context, gsink); } -SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); +SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &sink) const { + auto &gstate = sink.global_state.Cast(); + auto &lstate = sink.local_state.Cast(); - lstate.Sink(chunk); + auto &sort_strategy = *gstate.sort_strategies[gstate.child]; + auto &gsink = *gstate.strategy_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; - return SinkResultType::NEED_MORE_INPUT; + OperatorSinkInput hsink {gsink, lsink, sink.interrupt_state}; + return sort_strategy.Sink(context, chunk, hsink); } -SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - return SinkCombineResultType::FINISHED; +SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { + auto &gstate = combine.global_state.Cast(); + auto &lstate = combine.local_state.Cast(); + + auto &sort_strategy = *gstate.sort_strategies[gstate.child]; + auto &gsink = *gstate.strategy_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; + + OperatorSinkCombineInput hcombine {gsink, lsink, combine.interrupt_state}; + return sort_strategy.Combine(context, hcombine); } //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - // The data is all in so we can initialise the left partitioning. - const vector> partitions_stats; - gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, - children[0].get().GetTypes(), partitions_stats, 0U); - gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); - - // Find the first group to sort - if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; +SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, + OperatorSinkFinalizeInput &finalize) const { + auto &gstate = finalize.global_state.Cast(); + + // The data is all in so we can synchronise the left partitioning. + auto &sort_strategy = *gstate.sort_strategies[gstate.child]; + auto &hashed_sink = *gstate.strategy_sinks[gstate.child]; + OperatorSinkFinalizeInput hfinalize {hashed_sink, finalize.interrupt_state}; + if (gstate.child == 1) { + auto &lhs_groups = *gstate.strategy_sinks[1 - gstate.child]; + auto &rhs_groups = hashed_sink; + sort_strategy.Synchronize(rhs_groups, lhs_groups); } - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline, *this); - event.InsertEvent(std::move(new_event)); + // Switch sides + gstate.child = 1 - gstate.child; - return SinkFinalizeType::READY; + return sort_strategy.Finalize(client, hfinalize); +} + +OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, + GlobalOperatorState &gstate, OperatorState &lstate_p) const { + return OperatorResultType::FINISHED; } //===--------------------------------------------------------------------===// -// Operator +// Source //===--------------------------------------------------------------------===// -class AsOfGlobalState : public GlobalOperatorState { +enum class AsOfJoinSourceStage : uint8_t { INIT, SORT, MATERIALIZE, GET, LEFT, RIGHT, DONE }; + +struct AsOfSourceTask { + AsOfSourceTask() { + } + + AsOfJoinSourceStage stage = AsOfJoinSourceStage::DONE; + //! The hash group + idx_t group_idx = 0; + //! The thread index (for local state) + idx_t thread_idx = 0; + //! The total block index count + idx_t max_idx = 0; + //! The first block index count + idx_t begin_idx = 0; + //! 
The last block index count + idx_t end_idx = 0; +}; + +class AsOfPayloadScanner { public: - explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.rhs_sink; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); + using Types = vector; + using Columns = vector; + + AsOfPayloadScanner(const SortedRun &sorted_run, const SortStrategy &sort_strategy); + idx_t Base() const { + return base; + } + idx_t Scanned() const { + return scanned; + } + idx_t Remaining() const { + return count - scanned; + } + idx_t NextSize() const { + return MinValue(Remaining(), STANDARD_VECTOR_SIZE); + } + void SeekBlock(idx_t block_idx) { + chunk_idx = block_idx; + base = MinValue(chunk_idx * STANDARD_VECTOR_SIZE, count); + scanned = base; + } + inline void SeekRow(idx_t row_idx) { + SeekBlock(row_idx / STANDARD_VECTOR_SIZE); + } + bool Scan(DataChunk &chunk) { + // Free the previous blocks + block_state.SetKeepPinned(true); + block_state.SetPinPayload(true); + + base = scanned; + const auto result = (this->*scan_func)(); + chunk.ReferenceColumns(scan_chunk, scan_ids); + scanned += scan_chunk.size(); + ++chunk_idx; + return result; + } + +private: + template + bool TemplatedScan() { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(block_state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto result_count = NextSize(); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = block_state.GetIndex(chunk_idx, i); + sort_keys[i] = &itr[idx]; } + + // Scan + scan_chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, scan_chunk); + return scan_chunk.size() > 0; } + + // Only figure out the scan function once. 
+ using scan_t = bool (duckdb::AsOfPayloadScanner::*)(); + scan_t scan_func; + + const SortedRun &sorted_run; + ExternalBlockIteratorState block_state; + Vector sort_key_pointers = Vector(LogicalType::POINTER); + SortedRunScanState scan_state; + const Columns scan_ids; + DataChunk scan_chunk; + const idx_t count; + idx_t base = 0; + idx_t scanned = 0; + idx_t chunk_idx = 0; }; -unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); +AsOfPayloadScanner::AsOfPayloadScanner(const SortedRun &sorted_run, const SortStrategy &sort_strategy) + : sorted_run(sorted_run), block_state(*sorted_run.key_data, sorted_run.payload_data.get()), + scan_state(sorted_run.context, sorted_run.sort), scan_ids(sort_strategy.scan_ids), count(sorted_run.Count()) { + scan_chunk.Initialize(sorted_run.context, sort_strategy.payload_types); + const auto sort_key_type = sorted_run.key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + default: + throw NotImplementedException("AsOfPayloadScanner for %s", EnumUtil::ToString(sort_key_type)); + } } -class AsOfLocalState : public CachingOperatorState { +class AsOfHashGroup { public: - AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), - left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { - lhs_keys.Initialize(allocator, op.join_key_types); - for (const auto &cond : op.conditions) { - lhs_executor.AddExpression(*cond.left); - } + using HashGroupPtr = unique_ptr; + using ChunkRow = SortStrategy::ChunkRow; + + template + static T BinValue(T n, T val) { + return ((n + (val - 1)) / val); + } - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); + AsOfHashGroup(const PhysicalAsOfJoin &op, const ChunkRow &left_stats, const ChunkRow &right_stats, + const idx_t hash_group); - auto &gsink = op.sink_state->Cast(); - lhs_partition_sink = gsink.RegisterBuffer(context); + //! Is this a right join (do we have a RIGHT stage?) + inline bool IsRightOuter() const { + return right_outer.Enabled(); } - bool Sink(DataChunk &input); - OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); + //! The processing stage for this group + AsOfJoinSourceStage GetStage() const { + return stage; + } - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; + //! 
The total number of tasks we will execute + idx_t GetTaskCount() const { + return stage_begin[size_t(AsOfJoinSourceStage::DONE)]; + } - ExpressionExecutor lhs_executor; - DataChunk lhs_keys; - ValidityMask lhs_valid_mask; - SelectionVector lhs_sel; - DataChunk lhs_payload; + //! The number of left chunks + inline idx_t LeftChunks() const { + return left_stats.chunks; + } - OuterJoinMarker left_outer; - bool fetch_next_left; + //! The number of right chunks + inline idx_t RightChunks() const { + return right_stats.chunks; + } - optional_ptr lhs_partition_sink; + // Set up the task parameters + idx_t InitTasks(idx_t per_thread); + + //! The maximum number of chunks that we will scan for each state + idx_t MaximumChunks() const { + return MaxValue(LeftChunks(), RightChunks()); + } + + //! Try to move to the next stage + bool TryPrepareNextStage(); + //! Try to get another task for this group + bool TryNextTask(AsOfSourceTask &task); + //! Finish the given task. Returns true if there are no more tasks. + bool FinishTask(AsOfSourceTask &task); + + //! The parent operator + const PhysicalAsOfJoin &op; + //! The group number + const idx_t group_idx; + //! The number of left chunks/rows + const ChunkRow left_stats; + //! The number of right chunks/rows + const ChunkRow right_stats; + //! The left hash partition data + HashGroupPtr left_group; + //! The right hash partition data + HashGroupPtr right_group; + //! The right outer join markers + OuterJoinMarker right_outer; + // The processing stage for this group + AsOfJoinSourceStage stage; + //! The number of blocks per thread. + idx_t per_thread = 0; + //! The number of tasks per stage. + vector stage_tasks; + //! The first task in the stage. + vector stage_begin; + //! The next task to process + idx_t next_task = 0; + //! Count of sorting tasks completed + std::atomic sorted; + //! Count of materialization tasks completed + std::atomic materialized; + //! Count of get tasks completed + std::atomic gotten; + //! Count of left side tasks completed + std::atomic left_completed; + //! Count of right side tasks completed + std::atomic right_completed; }; -bool AsOfLocalState::Sink(DataChunk &input) { - // Compute the join keys - lhs_keys.Reset(); - lhs_executor.Execute(input, lhs_keys); - lhs_keys.Flatten(); +AsOfHashGroup::AsOfHashGroup(const PhysicalAsOfJoin &op, const ChunkRow &left_stats, const ChunkRow &right_stats, + const idx_t hash_group) + : op(op), group_idx(hash_group), left_stats(left_stats), right_stats(right_stats), + right_outer(IsRightOuterJoin(op.join_type)), stage(AsOfJoinSourceStage::INIT), sorted(0), materialized(0), + gotten(0), left_completed(0), right_completed(0) { + right_outer.Initialize(right_stats.count); +}; - // Combine the NULLs - const auto count = input.size(); - lhs_valid_mask.Reset(); - for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; - UnifiedVectorFormat unified; - col.ToUnifiedFormat(count, unified); - lhs_valid_mask.Combine(unified.validity, count); +idx_t AsOfHashGroup::InitTasks(idx_t per_thread_p) { + per_thread = per_thread_p; + + // INIT + stage_tasks.emplace_back(0); + + // SORT + auto materialize_tasks = BinValue(LeftChunks(), per_thread); + materialize_tasks += BinValue(RightChunks(), per_thread); + stage_tasks.emplace_back(materialize_tasks); + + // MATERIALIZE + stage_tasks.emplace_back(materialize_tasks); + + // GET + stage_tasks.emplace_back(materialize_tasks ?
1 : 0); + + // LEFT + const auto left_tasks = BinValue(LeftChunks(), per_thread); + stage_tasks.emplace_back(left_tasks); + + // RIGHT + const auto right_chunks = IsRightOuter() ? RightChunks() : 0; + const auto right_tasks = BinValue(right_chunks, per_thread); + stage_tasks.emplace_back(right_tasks); + + // DONE + stage_tasks.emplace_back(0); + + // Accumulate task counts so we can find boundaries reliably + idx_t begin = 0; + for (const auto &stage_task : stage_tasks) { + stage_begin.emplace_back(begin); + begin += stage_task; } - // Convert the mask to a selection vector - // and mark all the rows that cannot match for early return. - idx_t lhs_valid = 0; - const auto entry_count = lhs_valid_mask.EntryCount(count); - idx_t base_idx = 0; - left_outer.Reset(); - for (idx_t entry_idx = 0; entry_idx < entry_count;) { - const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); - const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - for (; base_idx < next; ++base_idx) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - base_idx = next; - } else { - const auto start = base_idx; - for (; base_idx < next; ++base_idx) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } + stage = AsOfJoinSourceStage(1); + + return GetTaskCount(); +} + +bool AsOfHashGroup::TryPrepareNextStage() { + switch (stage) { + case AsOfJoinSourceStage::INIT: + stage = AsOfJoinSourceStage::SORT; + return true; + case AsOfJoinSourceStage::SORT: + if (sorted >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::MATERIALIZE; + return true; } + break; + case AsOfJoinSourceStage::MATERIALIZE: + if (materialized >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::GET; + return true; + } + break; + case AsOfJoinSourceStage::GET: + if (gotten >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::LEFT; + return true; + } + break; + case AsOfJoinSourceStage::LEFT: + if (left_completed >= stage_tasks[size_t(stage)]) { + stage = stage_tasks[size_t(AsOfJoinSourceStage::RIGHT)] ? 
AsOfJoinSourceStage::RIGHT + : AsOfJoinSourceStage::DONE; + return true; + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_completed >= stage_tasks[size_t(stage)]) { + stage = AsOfJoinSourceStage::DONE; + return true; + } + break; + case AsOfJoinSourceStage::DONE: + return true; } - // Slice the keys to the ones we can match - lhs_payload.Reset(); - if (lhs_valid == count) { - lhs_payload.Reference(input); - lhs_payload.SetCardinality(input); - } else { - lhs_payload.Slice(input, lhs_sel, lhs_valid); - lhs_payload.SetCardinality(lhs_valid); + return false; +} + +bool AsOfHashGroup::TryNextTask(AsOfSourceTask &task) { + if (next_task >= GetTaskCount()) { + return false; + } - // Flush the ones that can't match - fetch_next_left = false; + // Search for where we are in the task list + for (idx_t stage = idx_t(AsOfJoinSourceStage::INIT); stage <= idx_t(AsOfJoinSourceStage::DONE); ++stage) { + if (next_task < stage_begin[stage]) { + task.stage = AsOfJoinSourceStage(stage - 1); + task.thread_idx = next_task - stage_begin[size_t(task.stage)]; + break; + } } - lhs_partition_sink->Sink(lhs_payload); + if (task.stage != GetStage()) { + return false; + } - return false; + task.group_idx = group_idx; + task.begin_idx = 0; + task.end_idx = 0; + + switch (task.stage) { + case AsOfJoinSourceStage::SORT: + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks() + RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + break; + case AsOfJoinSourceStage::MATERIALIZE: + if (!left_group || !right_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks() + RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::GET: + task.begin_idx = 0; + task.end_idx = 1; + task.max_idx = 1; + break; + case AsOfJoinSourceStage::LEFT: + if (left_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = LeftChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_group) { + task.begin_idx = task.thread_idx * per_thread; + task.max_idx = RightChunks(); + task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); + } + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + ++next_task; + + return true; +} + +bool AsOfHashGroup::FinishTask(AsOfSourceTask &task) { + // Inside the lock + switch (task.stage) { + case AsOfJoinSourceStage::SORT: + case AsOfJoinSourceStage::MATERIALIZE: + case AsOfJoinSourceStage::GET: + break; + case AsOfJoinSourceStage::LEFT: + if (left_completed >= stage_tasks[size_t(task.stage)]) { + left_group.reset(); + if (!IsRightOuter()) { + right_group.reset(); + } + } + break; + case AsOfJoinSourceStage::RIGHT: + if (right_completed >= stage_tasks[size_t(task.stage)]) { + right_group.reset(); + } + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + return (materialized + gotten + left_completed + right_completed) >= GetTaskCount(); } -OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { - input.Verify(); - Sink(input); +class AsOfLocalSourceState; - // If there were any unmatchable rows, return them now so we can forget about them. 
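InitTasks sizes each stage with a ceiling division of the chunk counts and then turns the per-stage counts into running offsets (stage_begin), which TryNextTask later searches to map a global task index back to a stage. A small self-contained sketch of that arithmetic, with hypothetical chunk counts, is shown here; the stage layout mirrors the enum above, but the numbers are illustrative only.

#include <cstdint>
#include <cstdio>
#include <vector>

using idx_t = uint64_t;

// Ceiling division: how many per_thread-sized tasks are needed to cover n chunks.
static idx_t BinValue(idx_t n, idx_t val) {
	return (n + (val - 1)) / val;
}

int main() {
	const idx_t per_thread = 4;
	const idx_t left_chunks = 10;
	const idx_t right_chunks = 3;

	// Per-stage task counts: INIT, SORT, MATERIALIZE, GET, LEFT, RIGHT, DONE
	const idx_t materialize_tasks = BinValue(left_chunks, per_thread) + BinValue(right_chunks, per_thread);
	std::vector<idx_t> stage_tasks = {0,
	                                  materialize_tasks,
	                                  materialize_tasks,
	                                  1,
	                                  BinValue(left_chunks, per_thread),
	                                  BinValue(right_chunks, per_thread),
	                                  0};

	// Running offsets: the first global task index of each stage.
	std::vector<idx_t> stage_begin;
	idx_t begin = 0;
	for (const auto stage_task : stage_tasks) {
		stage_begin.push_back(begin);
		begin += stage_task;
	}

	// Map a global task index back to (stage, local index), like TryNextTask does.
	const idx_t next_task = 7;
	for (size_t stage = 1; stage < stage_begin.size(); ++stage) {
		if (next_task < stage_begin[stage]) {
			std::printf("task %llu runs in stage %zu with local index %llu\n",
			            (unsigned long long)next_task, stage - 1,
			            (unsigned long long)(next_task - stage_begin[stage - 1]));
			break;
		}
	}
	// With these numbers: stage counts are {0,4,4,1,3,1,0}, offsets {0,0,4,8,9,12,13},
	// so task 7 is the fourth MATERIALIZE task (stage 2, local index 3).
}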
- if (!fetch_next_left) { - fetch_next_left = true; - left_outer.ConstructLeftJoinResult(input, chunk); - left_outer.Reset(); +class AsOfGlobalSourceState : public GlobalSourceState { +public: + using AsOfHashGroupPtr = unique_ptr; + using AsOfHashGroups = vector; + using HashedSourceStatePtr = unique_ptr; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; + using PartitionBlock = std::pair; + + AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op); + + //! Are there any more tasks? + bool HasMoreTasks() const { + return !stopped && started < total_tasks; + } + bool HasUnfinishedTasks() const { + return !stopped && finished < total_tasks; + } + + //! Assign a new task to the local state + bool TryNextTask(TaskPtr &task, Task &task_local); + + //! The parent operator + const PhysicalAsOfJoin &op; + //! The source states for the hashed sort + vector hashed_sources; + //! The hash groups + AsOfHashGroups asof_groups; + //! The sorted list of (blocks, group_idx) pairs + vector partition_blocks; + //! The ordered set of active groups + vector active_groups; + //! The next group to start + atomic next_group; + //! The total number of tasks + idx_t total_tasks = 0; + //! The number of started tasks + atomic started; + //! The number of tasks finished. + atomic finished; + //! Stop producing tasks + atomic stopped; + +public: + idx_t MaxThreads() override { + return total_tasks; + } + +protected: + //! Build task list + void CreateTaskList(ClientContext &client); + //! Finish a task + void FinishTask(TaskPtr task); +}; + +AsOfGlobalSourceState::AsOfGlobalSourceState(ClientContext &client, const PhysicalAsOfJoin &op) + : op(op), next_group(0), started(0), finished(0), stopped(false) { + // Take ownership of the hash groups + auto &gsink = op.sink_state->Cast(); + + using ChunkRow = SortStrategy::ChunkRow; + using ChunkRows = SortStrategy::ChunkRows; + vector child_groups(2); + for (idx_t child = 0; child < child_groups.size(); ++child) { + auto &sort_strategy = *gsink.sort_strategies[child]; + auto &hashed_sink = *gsink.strategy_sinks[child]; + auto hashed_source = sort_strategy.GetGlobalSourceState(client, hashed_sink); + child_groups[child] = sort_strategy.GetHashGroups(*hashed_source); + hashed_sources.emplace_back(std::move(hashed_source)); + } + + // Pivot into AsOfHashGroups + auto &lhs_groups = child_groups[0]; + auto &rhs_groups = child_groups[1]; + const auto group_count = MaxValue(lhs_groups.size(), rhs_groups.size()); + for (idx_t group_idx = 0; group_idx < group_count; ++group_idx) { + ChunkRow lhs_stats; + if (group_idx < lhs_groups.size()) { + lhs_stats = lhs_groups[group_idx]; + } + ChunkRow rhs_stats; + if (group_idx < rhs_groups.size()) { + rhs_stats = rhs_groups[group_idx]; + } + auto asof_group = make_uniq(op, lhs_stats, rhs_stats, group_idx); + asof_groups.emplace_back(std::move(asof_group)); } - // Just keep asking for data and buffering it - return OperatorResultType::NEED_MORE_INPUT; + CreateTaskList(client); } -OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &lstate_p) const { - auto &gsink = sink_state->Cast(); - auto &lstate = lstate_p.Cast(); +void AsOfGlobalSourceState::CreateTaskList(ClientContext &client) { + // Sort the groups from largest to smallest + if (asof_groups.empty()) { + return; + } - if (gsink.rhs_sink.count == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, 
gsink.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; + // Count chunks, not rows (otherwise left and right raggedness could give the wrong answer + for (idx_t group_idx = 0; group_idx < asof_groups.size(); ++group_idx) { + auto &asof_hash_group = asof_groups[group_idx]; + if (!asof_hash_group) { + continue; } + partition_blocks.emplace_back(asof_hash_group->MaximumChunks(), group_idx); } + std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); + const auto &max_block = partition_blocks.front(); - return lstate.ExecuteInternal(context, input, chunk); + // Schedule the largest group on as many threads as possible + auto &ts = TaskScheduler::GetScheduler(client); + const auto threads = NumericCast(ts.NumberOfThreads()); + + const auto per_thread = AsOfHashGroup::BinValue(max_block.first, threads); + if (!per_thread) { + throw InternalException("No blocks per AsOf thread! %ld threads, %ld groups, %ld blocks, %ld hash group", + threads, partition_blocks.size(), max_block.first, max_block.second); + } + + for (const auto &b : partition_blocks) { + total_tasks += asof_groups[b.second]->InitTasks(per_thread); + } } -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// +enum class SortKeyPrefixComparisonType : uint8_t { FIXED, VARCHAR, NESTED }; + +struct SortKeyPrefixComparisonColumn { + SortKeyPrefixComparisonType type; + idx_t size; +}; + +struct SortKeyPrefixComparisonResult { + //! The column at which the sides are no longer equal, + //! e.g., Compare([42, 84], [42, 83]) would return {1, COMPARE_GREATERTHAN} + idx_t column_index; + //! Either COMPARE_EQUAL, COMPARE_LESSTHAN, COMPARE_GREATERTHAN + ExpressionType type; +}; + +struct SortKeyPrefixComparison { + unsafe_vector columns; + //! Two row buffer for measuring lhs and rhs widths for nested types. + //! Gross, but there is currently no way to measure the width of a single key + //! except as a side-effect of decoding it... + DataChunk decoded; + + template + SortKeyPrefixComparisonResult Compare(const SORT_KEY &lhs, const SORT_KEY &rhs) { + SortKeyPrefixComparisonResult result {0, ExpressionType::COMPARE_EQUAL}; + + auto lhs_copy = lhs; + string_t lhs_key; + lhs_copy.Deconstruct(lhs_key); + auto lhs_ptr = lhs_key.GetData(); + + auto rhs_copy = rhs; + string_t rhs_key; + rhs_copy.Deconstruct(rhs_key); + auto rhs_ptr = rhs_key.GetData(); + + // Partition keys are always sorted this way. + OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST); + + for (column_t col_idx = 0; col_idx < columns.size(); ++col_idx) { + const auto &col = columns[col_idx]; + auto &vec = decoded.data[col_idx]; + auto lhs_width = col.size; + auto rhs_width = col.size; + int cmp = 1; + switch (col.type) { + case SortKeyPrefixComparisonType::FIXED: + cmp = memcmp(lhs_ptr, rhs_ptr, lhs_width); + break; + case SortKeyPrefixComparisonType::VARCHAR: + // Include first null byte. 
+ lhs_width = 1 + strlen(lhs_ptr); + rhs_width = 1 + strlen(rhs_ptr); + cmp = memcmp(lhs_ptr, rhs_ptr, MinValue(lhs_width, rhs_width)); + break; + case SortKeyPrefixComparisonType::NESTED: + decoded.Reset(); + lhs_width = CreateSortKeyHelpers::DecodeSortKey(lhs_key, vec, 0, modifiers); + rhs_width = CreateSortKeyHelpers::DecodeSortKey(rhs_key, vec, 1, modifiers); + cmp = memcmp(lhs_ptr, rhs_ptr, MinValue(lhs_width, rhs_width)); + if (!cmp) { + cmp = (rhs_width < lhs_width) - (lhs_width < rhs_width); + } + break; + } + + if (cmp) { + result.type = (cmp < 0) ? ExpressionType::COMPARE_LESSTHAN : ExpressionType::COMPARE_GREATERTHAN; + return result; + } + + ++result.column_index; + lhs_ptr += lhs_width; + rhs_ptr += rhs_width; + } + + return result; + } +}; + class AsOfProbeBuffer { public: using Orders = vector; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; - static bool IsExternal(ClientContext &context) { - return ClientConfig::GetConfig(context).force_external; - } - - AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); + AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource); public: - void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); - bool Scanning() const { - return lhs_scanner.get(); + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw NotImplementedException("Unsupported comparison type for ASOF join"); + } } - void BeginLeftScan(hash_t scan_bin); + + //! Is left cmp right? + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); + } + return less_than; + } + + template + void ResolveJoin(idx_t *matches); + + using resolve_join_t = void (duckdb::AsOfProbeBuffer::*)(idx_t *); + resolve_join_t resolve_join_func; + + void BeginLeftScan(TaskPtr task); bool NextLeft(); + void ScanLeft(); void EndLeftScan(); + //! Create a new iterator for the sorted run + static unique_ptr CreateIteratorState(SortedRun &sorted) { + auto state = make_uniq(*sorted.key_data, sorted.payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + // Don't pin the payload because we are not using it here. 
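AsOfProbeBuffer::Compare derives both the strict (<) and the non-strict (<=) inequality from operator< alone, which is all the sort keys provide. A tiny check of the truth table, assuming plain integers for illustration:

#include <cassert>

// Is lhs "cmp" rhs, where cmp is < when strict and <= otherwise, using only operator<.
template <class T>
static bool Compare(const T &lhs, const T &rhs, const bool strict) {
	const bool less_than = lhs < rhs;
	if (!less_than && !strict) {
		// Not less than: equality counts as a match only for the non-strict comparison.
		return !(rhs < lhs);
	}
	return less_than;
}

int main() {
	assert(Compare(1, 2, true));   // 1 <  2
	assert(!Compare(2, 2, true));  // 2 <  2 is false
	assert(Compare(2, 2, false));  // 2 <= 2
	assert(!Compare(3, 2, false)); // 3 <= 2 is false
}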
+ iter.SetKeepPinned(true); + } // resolve joins that output max N elements (SEMI, ANTI, MARK) void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) + // resolve joins that can potentially output N*M elements (LEFT, LEFT, FULL) void ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk); // Chunk may be empty void GetData(ExecutionContext &context, DataChunk &chunk); bool HasMoreData() const { - return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); + return !fetch_next_left || (task->begin_idx < task->end_idx); } - ClientContext &context; - Allocator &allocator; + ClientContext &client; const PhysicalAsOfJoin &op; - BufferManager &buffer_manager; - const bool force_external; - const idx_t memory_per_thread; - Orders lhs_orders; + //! The source state + AsOfGlobalSourceState &gsource; + //! Is the inequality strict? + const bool strict; + //! The current hash group + optional_ptr asof_hash_group; + //! The task we are processing + TaskPtr task; // LHS scanning - SelectionVector lhs_sel; - optional_ptr left_hash; + optional_ptr left_group; OuterJoinMarker left_outer; - unique_ptr left_itr; - unique_ptr lhs_scanner; + unique_ptr left_itr; + unique_ptr lhs_scanner; DataChunk lhs_payload; - idx_t left_group = 0; + ExpressionExecutor lhs_executor; + DataChunk lhs_keys; + ValidityMask lhs_valid_mask; + idx_t left_bin = 0; + SelectionVector lhs_match_sel; // RHS scanning - optional_ptr right_hash; + optional_ptr right_group; optional_ptr right_outer; - unique_ptr right_itr; - unique_ptr rhs_scanner; + unique_ptr right_itr; + idx_t right_pos; // ExternalBlockIteratorState doesn't know this... + unique_ptr rhs_scanner; DataChunk rhs_payload; - idx_t right_group = 0; + DataChunk rhs_input; + SelectionVector rhs_match_sel; + idx_t right_bin = 0; // Predicate evaluation - SelectionVector filter_sel; - ExpressionExecutor filterer; - idx_t lhs_match_count; bool fetch_next_left; + + SortKeyPrefixComparison prefix; }; -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), - memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), filterer(context), - fetch_next_left(true) { - vector> partition_stats; - Orders partitions; // Not used. 
- PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, - partition_stats); - - // We sort the row numbers of the incoming block, not the rows - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - rhs_payload.Initialize(allocator, op.children[1].get().GetTypes()); - - lhs_sel.Initialize(); +AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op, AsOfGlobalSourceState &gsource) + : client(client), op(op), gsource(gsource), strict(IsStrictComparison(op.comparison_type)), + left_outer(IsLeftOuterJoin(op.join_type)), lhs_executor(client), fetch_next_left(true) { + lhs_keys.Initialize(client, op.join_key_types); + for (const auto &cond : op.conditions) { + lhs_executor.AddExpression(*cond.left); + } + + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + rhs_payload.Initialize(client, op.children[1].get().GetTypes()); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); + + lhs_match_sel.Initialize(); + rhs_match_sel.Initialize(); left_outer.Initialize(STANDARD_VECTOR_SIZE); - if (op.predicate) { - filter_sel.Initialize(); - filterer.AddExpression(*op.predicate); + // If we have equality predicates, set up the prefix data. + vector prefix_types; + for (idx_t i = 0; i < op.conditions.size() - 1; ++i) { + const auto &cond = op.conditions[i]; + const auto &type = cond.left->return_type; + prefix_types.emplace_back(type); + SortKeyPrefixComparisonColumn col; + col.size = DConstants::INVALID_INDEX; + switch (type.id()) { + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + col.type = SortKeyPrefixComparisonType::VARCHAR; + break; + case LogicalTypeId::STRUCT: + case LogicalTypeId::LIST: + case LogicalTypeId::ARRAY: + col.type = SortKeyPrefixComparisonType::NESTED; + break; + default: + col.type = SortKeyPrefixComparisonType::FIXED; + col.size = 1 + GetTypeIdSize(type.InternalType()); + break; + } + prefix.columns.emplace_back(col); + } + if (!prefix_types.empty()) { + // LHS, RHS + prefix.decoded.Initialize(client, prefix_types, 2); } } -void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { +void AsOfProbeBuffer::BeginLeftScan(TaskPtr task_p) { auto &gsink = op.sink_state->Cast(); + task = task_p; + const auto scan_bin = task->group_idx; - auto &lhs_sink = *gsink.lhs_sink; - left_group = lhs_sink.bin_groups[scan_bin]; + asof_hash_group = gsource.asof_groups[scan_bin].get(); - // Always set right_group too for memory management - auto &rhs_sink = gsink.rhs_sink; - if (scan_bin < rhs_sink.bin_groups.size()) { - right_group = rhs_sink.bin_groups[scan_bin]; - } else { - right_group = rhs_sink.bin_groups.size(); - } + // Always set right_bin too for memory management + right_group = asof_hash_group->right_group; + right_bin = right_group ? scan_bin : gsource.asof_groups.size(); - if (left_group >= lhs_sink.bin_groups.size()) { + left_group = asof_hash_group->left_group; + left_bin = left_group ? 
scan_bin : gsource.asof_groups.size(); + if (!left_group || !left_group->Count()) { return; } - auto iterator_comp = ExpressionType::INVALID; - switch (op.comparison_type) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; + // Set up function pointer for sort type + const auto sort_key_type = left_group->key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_GREATERTHAN: - iterator_comp = ExpressionType::COMPARE_LESSTHAN; + case SortKeyType::NO_PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + case SortKeyType::NO_PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHAN: - iterator_comp = ExpressionType::COMPARE_GREATERTHAN; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; default: throw NotImplementedException("Unsupported comparison type for ASOF join"); } - left_hash = lhs_sink.hash_groups[left_group].get(); - auto &left_sort = *(left_hash->global_sort); - if (left_sort.sorted_blocks.empty()) { - return; - } - lhs_scanner = make_uniq(left_sort, false); - left_itr = make_uniq(left_sort, iterator_comp); + lhs_scanner = make_uniq(*left_group, *gsink.sort_strategies[0]); + lhs_scanner->SeekBlock(task->begin_idx); + left_itr = CreateIteratorState(*left_group); // We are only probing the corresponding right side bin, which may be empty - // If they are empty, we leave the iterator as null so we can emit left matches - if (right_group < rhs_sink.bin_groups.size()) { - right_hash = rhs_sink.hash_groups[right_group].get(); - right_outer = gsink.right_outers.data() + right_group; - auto &right_sort = *(right_hash->global_sort); - right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); + // If it is empty, we leave the iterator as null so we can emit left matches + right_pos = 0; + if (right_group) { + right_outer = &asof_hash_group->right_outer; + if (right_group && right_group->Count()) { + right_itr = CreateIteratorState(*right_group); + rhs_scanner = make_uniq(*right_group, *gsink.sort_strategies[1]); + } } } bool AsOfProbeBuffer::NextLeft() { - if (!HasMoreData()) { - return false; - } + return task->begin_idx < task->end_idx; +} +void AsOfProbeBuffer::ScanLeft() { // Scan the next sorted chunk lhs_payload.Reset(); - left_itr->SetIndex(lhs_scanner->Scanned()); lhs_scanner->Scan(lhs_payload); + ++task->begin_idx; - return true; + // Compute the join keys + lhs_keys.Reset(); + lhs_executor.Execute(lhs_payload, lhs_keys); + lhs_keys.Flatten(); + + // Combine the NULLs + const auto count = lhs_payload.size(); + lhs_valid_mask.Reset(); + for (auto col_idx : op.null_sensitive) { + auto &col = 
lhs_keys.data[col_idx]; + UnifiedVectorFormat unified; + col.ToUnifiedFormat(count, unified); + lhs_valid_mask.Combine(unified.validity, count); + } + + // Filter out NULL matches + if (!lhs_valid_mask.AllValid()) { + const auto count = lhs_match_count; + lhs_match_count = 0; + for (idx_t i = 0; i < count; ++i) { + const auto idx = lhs_match_sel.get_index(i); + if (lhs_valid_mask.RowIsValidUnsafe(idx)) { + lhs_match_sel.set_index(lhs_match_count++, idx); + } + } + } } void AsOfProbeBuffer::EndLeftScan() { - auto &gsink = op.sink_state->Cast(); + if (task->stage != AsOfJoinSourceStage::LEFT) { + return; + } + task->stage = AsOfJoinSourceStage::DONE; - right_hash = nullptr; + D_ASSERT(asof_hash_group); + asof_hash_group->left_completed++; + + right_group = nullptr; right_itr.reset(); rhs_scanner.reset(); right_outer = nullptr; - auto &rhs_sink = gsink.rhs_sink; - if (!gsink.is_outer && right_group < rhs_sink.bin_groups.size()) { - rhs_sink.hash_groups[right_group].reset(); - } - - left_hash = nullptr; + left_group = nullptr; left_itr.reset(); lhs_scanner.reset(); - - auto &lhs_sink = *gsink.lhs_sink; - if (left_group < lhs_sink.bin_groups.size()) { - lhs_sink.hash_groups[left_group].reset(); - } } -void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { +template +void AsOfProbeBuffer::ResolveJoin(idx_t *matches) { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + // If there was no right partition, there are no matches lhs_match_count = 0; if (!right_itr) { return; } - const auto count = lhs_payload.size(); - const auto left_base = left_itr->GetIndex(); + Repin(*left_itr); + BLOCKS_ITERATOR left_key(*left_itr); + + Repin(*right_itr); + BLOCKS_ITERATOR right_key(*right_itr); + + const auto count = lhs_scanner->NextSize(); + const auto left_base = lhs_scanner->Scanned(); // Searching for right <= left for (idx_t i = 0; i < count; ++i) { - left_itr->SetIndex(left_base + i); - // If right > left, then there is no match - if (!right_itr->Compare(*left_itr)) { + const auto left_pos = left_base + i; + if (!Compare(right_key[right_pos], left_key[left_pos], strict)) { continue; } // Exponential search forward for a non-matching value using radix iterators // (We use exponential search to avoid thrashing the block manager on large probes) idx_t bound = 1; - idx_t begin = right_itr->GetIndex(); - right_itr->SetIndex(begin + bound); - while (right_itr->GetIndex() < right_hash->count) { - if (right_itr->Compare(*left_itr)) { + idx_t begin = right_pos; + while (begin + bound < right_group->Count()) { + if (Compare(right_key[begin + bound], left_key[left_pos], strict)) { // If right <= left, jump ahead bound *= 2; - right_itr->SetIndex(begin + bound); } else { break; } @@ -539,43 +1112,46 @@ void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { // Binary search for the first non-matching value using radix iterators // The previous value (which we know exists) is the match auto first = begin + bound / 2; - auto last = MinValue(begin + bound, right_hash->count); + auto last = MinValue(begin + bound, right_group->Count()); while (first < last) { const auto mid = first + (last - first) / 2; - right_itr->SetIndex(mid); - if (right_itr->Compare(*left_itr)) { + if (Compare(right_key[mid], left_key[left_pos], strict)) { // If right <= left, new lower bound first = mid + 1; } else { last = mid; } } - right_itr->SetIndex(--first); + right_pos = --first; // Check partitions for strict equality - if (right_hash->ComparePartitions(*left_itr, *right_itr)) { - 
continue; + if (!prefix.columns.empty()) { + const auto cmp = prefix.Compare(left_key[left_pos], right_key[right_pos]); + if (cmp.column_index < prefix.columns.size()) { + continue; + } } // Emit match data - if (found_match) { - found_match[i] = true; - } if (matches) { matches[i] = first; } - lhs_sel.set_index(lhs_match_count++, i); + lhs_match_sel.set_index(lhs_match_count++, i); } } -unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join + (this->*resolve_join_func)(nullptr); + + // Scan the lhs values (after comparing keys) and filter out the LHS NULLs + ScanLeft(); + + // Convert the match selection to simple join mask bool found_match[STANDARD_VECTOR_SIZE] = {false}; - ResolveJoin(found_match); + for (idx_t i = 0; i < lhs_match_count; ++i) { + found_match[lhs_match_sel.get_index(i)] = true; + } // now construct the result based on the join result switch (op.join_type) { @@ -593,43 +1169,51 @@ void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &ch void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join idx_t matches[STANDARD_VECTOR_SIZE]; - ResolveJoin(nullptr, matches); + (this->*resolve_join_func)(matches); + // Scan the lhs values (after comparing keys) and filter out the LHS NULLs + ScanLeft(); + + // Extract the rhs input columns from the match + rhs_input.Reset(); + idx_t rhs_match_count = 0; for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = lhs_sel[i]; + const auto idx = lhs_match_sel[i]; const auto match_pos = matches[idx]; // Skip to the range containing the match - while (match_pos >= rhs_scanner->Scanned()) { + if (match_pos >= rhs_scanner->Scanned()) { + if (rhs_match_count) { + rhs_input.Append(rhs_payload, false, &rhs_match_sel, rhs_match_count); + rhs_match_count = 0; + } rhs_payload.Reset(); + rhs_scanner->SeekRow(match_pos); rhs_scanner->Scan(rhs_payload); } - // Append the individual values - // TODO: Batch the copies - const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); - for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { - const auto rhs_idx = op.right_projection_map[col_idx]; - auto &source = rhs_payload.data[rhs_idx]; - auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; - VectorOperations::Copy(source, target, source_offset + 1, source_offset, i); - } + // Select the individual values + const auto source_offset = match_pos - rhs_scanner->Base(); + rhs_match_sel.set_index(rhs_match_count++, source_offset); } + rhs_input.Append(rhs_payload, false, &rhs_match_sel, rhs_match_count); // Slice the left payload into the result for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { - chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); + chunk.data[i].Slice(lhs_payload.data[i], lhs_match_sel, lhs_match_count); } - chunk.SetCardinality(lhs_match_count); - auto match_sel = &lhs_sel; - if (filterer.expressions.size() == 1) { - lhs_match_count = filterer.SelectExpression(chunk, filter_sel); - chunk.Slice(filter_sel, lhs_match_count); - match_sel = &filter_sel; + + // Reference the projected right payload into the result + for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + auto &source = 
rhs_input.data[rhs_idx]; + auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; + target.Reference(source); } + chunk.SetCardinality(lhs_match_count); // Update the match masks for the rows we ended up with left_outer.Reset(); for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = match_sel->get_index(i); + const auto idx = lhs_match_sel.get_index(i); left_outer.SetMatch(idx); const auto first = matches[idx]; right_outer->SetMatch(first); @@ -675,241 +1259,415 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { } } -class AsOfGlobalSourceState : public GlobalSourceState { -public: - explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { - } - - PartitionGlobalMergeStates &GetMergeStates() { - lock_guard guard(lock); - if (!merge_states) { - merge_states = make_uniq(*gsink.lhs_sink); - } - return *merge_states; - } - - AsOfGlobalSinkState &gsink; - //! The next buffer to combine - atomic next_combine; - //! The number of combined buffers - atomic combined; - //! The number of combined buffers - atomic merged; - //! The number of combined buffers - atomic mergers; - //! The next buffer to flush - atomic next_left; - //! The number of flushed buffers - atomic flushed; - //! The right outer output read position. - atomic next_right; - //! The merge handler - mutex lock; - unique_ptr merge_states; - -public: - idx_t MaxThreads() override { - return gsink.lhs_buffers.size(); - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); +unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &client) const { + return make_uniq(client, *this); } class AsOfLocalSourceState : public LocalSourceState { public: - using HashGroupPtr = unique_ptr; - - AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); - - // Return true if we were not interrupted (another thread died) - bool CombineLeftPartitions(); - bool MergeLeftPartitions(); + using HashGroupPtr = optional_ptr; + using Task = AsOfSourceTask; + using TaskPtr = optional_ptr; + + AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op); + + //! Task management + bool TaskFinished() const; + //! 
Assign the next task + bool TryAssignTask(); + + void ExecuteSortTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteMaterializeTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteGetTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteLeftTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + void ExecuteRightTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source); + + void ExecuteTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + ExecuteSortTask(context, chunk, source); + break; + case AsOfJoinSourceStage::MATERIALIZE: + ExecuteMaterializeTask(context, chunk, source); + break; + case AsOfJoinSourceStage::GET: + ExecuteGetTask(context, chunk, source); + break; + case AsOfJoinSourceStage::LEFT: + ExecuteLeftTask(context, chunk, source); + break; + case AsOfJoinSourceStage::RIGHT: + ExecuteRightTask(context, chunk, source); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + throw InternalException("Invalid state for AsOf Task"); + } + } - idx_t BeginRightScan(const idx_t hash_bin); + void BeginRightScan(); + void EndRightScan(); AsOfGlobalSourceState &gsource; - ClientContext &client; + ExecutionContext &context; //! The left side partition being probed AsOfProbeBuffer probe_buffer; - //! The read partition - idx_t hash_bin; + //! The task this thread is working on + TaskPtr task; + //! The task storage + Task task_local; + //! The rhs group HashGroupPtr hash_group; //! The read cursor - unique_ptr scanner; - //! Pointer to the matches - const bool *found_match = {}; + unique_ptr scanner; + //! The right outer buffer + DataChunk rhs_chunk; + //! The right outer slicer + SelectionVector rsel; + //! 
Pointer to the right marker + const bool *rhs_matches = {}; }; -AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, - ClientContext &client_p) - : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { - gsource.mergers++; +AsOfLocalSourceState::AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, + const PhysicalAsOfJoin &op) + : gsource(gsource), context(context), probe_buffer(context.client, op, gsource), rsel(STANDARD_VECTOR_SIZE) { + rhs_chunk.Initialize(context.client, op.children[1].get().GetTypes()); } -bool AsOfLocalSourceState::CombineLeftPartitions() { - const auto buffer_count = gsource.gsink.lhs_buffers.size(); - while (gsource.combined < buffer_count && !client.interrupted) { - const auto next_combine = gsource.next_combine++; - if (next_combine < buffer_count) { - gsource.gsink.lhs_buffers[next_combine]->Combine(); - ++gsource.combined; - } else { - TaskScheduler::GetScheduler(client).YieldThread(); - } +bool AsOfLocalSourceState::TaskFinished() const { + if (!task) { + return true; } - return !client.interrupted; -} - -bool AsOfLocalSourceState::MergeLeftPartitions() { - PartitionGlobalMergeStates::Callback local_callback; - PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); - gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); - gsource.merged++; - while (gsource.merged < gsource.mergers && !client.interrupted) { - TaskScheduler::GetScheduler(client).YieldThread(); + if (task->stage == AsOfJoinSourceStage::LEFT && !probe_buffer.fetch_next_left) { + return false; } - return !client.interrupted; + + return task->begin_idx >= task->end_idx; } -idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { - hash_bin = hash_bin_p; +void AsOfLocalSourceState::BeginRightScan() { + const auto hash_bin = task->group_idx; - hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); - if (hash_group->global_sort->sorted_blocks.empty()) { - return 0; + auto &asof_groups = gsource.asof_groups; + if (hash_bin >= asof_groups.size()) { + return; + } + + hash_group = asof_groups[hash_bin]->right_group.get(); + if (!hash_group || !hash_group->Count()) { + return; } - scanner = make_uniq(*hash_group->global_sort); - found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); + auto &gsink = gsource.op.sink_state->Cast(); + scanner = make_uniq(*hash_group, *gsink.sort_strategies[1]); + scanner->SeekBlock(task->begin_idx); - return scanner->Remaining(); + rhs_matches = asof_groups[hash_bin]->right_outer.GetMatches(); +} + +void AsOfLocalSourceState::EndRightScan() { + D_ASSERT(task->stage == AsOfJoinSourceStage::RIGHT); + + auto &asof_groups = gsource.asof_groups; + const auto hash_bin = task->group_idx; + const auto &asof_hash_group = asof_groups[hash_bin]; + asof_hash_group->right_completed++; } unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { auto &gsource = gstate.Cast(); - return make_uniq(gsource, *this, context.client); + return make_uniq(context, gsource, *this); } -SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = input.local_state.Cast(); - auto &rhs_sink = gsource.gsink.rhs_sink; - auto &client = context.client; - - // Step 1: Combine the partitions - if (!lsource.CombineLeftPartitions()) { - return SourceResultType::FINISHED; 
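BeginRightScan wires up rhs_matches, the per-row flags the probe phase sets for the right side; the RIGHT stage later walks those flags and emits only the rows that never found a match, padding the left columns with NULLs. A stripped-down sketch of that selection step, using a hard-coded flag array in place of the outer-join marker:

#include <cstdio>
#include <vector>

using idx_t = unsigned long long;

int main() {
	// Match flags produced by the probe phase: true means the RHS row found an LHS match.
	const bool matches[] = {true, false, true, true, false, false};
	const idx_t count = 6;

	// Collect the indices of unmatched rows, as the right-outer scan does with its selection vector;
	// those rows are emitted with the LHS columns set to constant NULL.
	std::vector<idx_t> rsel;
	for (idx_t i = 0; i < count; ++i) {
		if (!matches[i]) {
			rsel.push_back(i);
		}
	}

	for (const auto idx : rsel) {
		std::printf("emit RHS row %llu with NULL LHS columns\n", idx);
	}
}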
- } - - // Step 2: Sort on all threads - if (!lsource.MergeLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 3: Join the partitions - auto &lhs_sink = *gsource.gsink.lhs_sink; - const auto left_bins = lhs_sink.grouping_data ? lhs_sink.grouping_data->GetPartitions().size() : 1; - while (gsource.flushed < left_bins) { - // Make sure we have something to flush - if (!lsource.probe_buffer.Scanning()) { - const auto left_bin = gsource.next_left++; - if (left_bin < left_bins) { - // More to flush - lsource.probe_buffer.BeginLeftScan(left_bin); - } else if (!IsRightOuterJoin(join_type) || client.interrupted) { - return SourceResultType::FINISHED; - } else { - // Wait for all threads to finish - // TODO: How to implement a spin wait correctly? - // Returning BLOCKED seems to hang the system. - TaskScheduler::GetScheduler(client).YieldThread(); - continue; - } +void AsOfGlobalSourceState::FinishTask(TaskPtr task) { + // Inside the lock + if (!task) { + return; + } + + ++finished; + + const auto group_idx = task->group_idx; + auto &finished_hash_group = asof_groups[group_idx]; + D_ASSERT(finished_hash_group); + + if (finished_hash_group->FinishTask(*task)) { + // Remove it from the active groups + auto &v = active_groups; + v.erase(std::remove(v.begin(), v.end(), group_idx), v.end()); + } +} + +bool AsOfLocalSourceState::TryAssignTask() { + D_ASSERT(TaskFinished()); + // Because downstream operators may be using our internal buffers, + // we can't "finish" a task until we are about to get the next one. + if (task) { + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + gsource.asof_groups[task_local.group_idx]->sorted++; + break; + case AsOfJoinSourceStage::MATERIALIZE: + gsource.asof_groups[task_local.group_idx]->materialized++; + break; + case AsOfJoinSourceStage::GET: + gsource.asof_groups[task_local.group_idx]->gotten++; + break; + case AsOfJoinSourceStage::LEFT: + probe_buffer.EndLeftScan(); + break; + case AsOfJoinSourceStage::RIGHT: + EndRightScan(); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; } + } - lsource.probe_buffer.GetData(context, chunk); - if (chunk.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (lsource.probe_buffer.HasMoreData()) { - // Join the next partition - continue; - } else { - lsource.probe_buffer.EndLeftScan(); - gsource.flushed++; + if (!gsource.TryNextTask(task, task_local)) { + return false; + } + + switch (task->stage) { + case AsOfJoinSourceStage::SORT: + case AsOfJoinSourceStage::MATERIALIZE: + case AsOfJoinSourceStage::GET: + break; + case AsOfJoinSourceStage::LEFT: + probe_buffer.BeginLeftScan(*task); + break; + case AsOfJoinSourceStage::RIGHT: + BeginRightScan(); + break; + case AsOfJoinSourceStage::INIT: + case AsOfJoinSourceStage::DONE: + break; + } + + return true; +} + +bool AsOfGlobalSourceState::TryNextTask(TaskPtr &task, Task &task_local) { + auto guard = Lock(); + FinishTask(task); + + if (!HasMoreTasks()) { + task = nullptr; + return false; + } + + // Run through the active groups looking for one that can assign a task + for (const auto &group_idx : active_groups) { + auto &asof_group = asof_groups[group_idx]; + if (asof_group->TryPrepareNextStage()) { + UnblockTasks(guard); + } + if (asof_group->TryNextTask(task_local)) { + task = task_local; + ++started; + return true; } } - // Step 4: Emit right join matches - if (!IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; + // All active groups are busy or blocked, so start the next one (if any) + 
while (next_group < partition_blocks.size()) { + const auto group_idx = partition_blocks[next_group++].second; + active_groups.emplace_back(group_idx); + + auto &asof_group = asof_groups[group_idx]; + if (asof_group->TryPrepareNextStage()) { + UnblockTasks(guard); + } + if (!asof_group->TryNextTask(task_local)) { + // Group has no tasks (empty?) + continue; + } + + task = task_local; + ++started; + return true; } - auto &hash_groups = rhs_sink.hash_groups; - const auto right_groups = hash_groups.size(); + task = nullptr; - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner || !lsource.scanner->Remaining()) { - lsource.scanner.reset(); - lsource.hash_group.reset(); - auto hash_bin = gsource.next_right++; - if (hash_bin >= right_groups) { - return SourceResultType::FINISHED; + return false; +} + +void AsOfLocalSourceState::ExecuteSortTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + // Left or right? + const idx_t child = task_local.begin_idx >= asof_group.LeftChunks(); + const auto &gsink = gsource.op.sink_state->Cast(); + auto &sort_strategy = *gsink.sort_strategies[child]; + auto &hashed_sink = *gsink.strategy_sinks[child]; + + OperatorSinkFinalizeInput finalize {hashed_sink, source.interrupt_state}; + sort_strategy.SortColumnData(context, task_local.group_idx, finalize); + + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteMaterializeTask(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + // Left or right? 
+ const idx_t child = task_local.begin_idx >= asof_group.LeftChunks(); + const auto &gsink = gsource.op.sink_state->Cast(); + auto &sort_strategy = *gsink.sort_strategies[child]; + auto &hashed_source = *gsource.hashed_sources[child]; + + auto unused = make_uniq(); + OperatorSourceInput hsource {hashed_source, *unused, source.interrupt_state}; + sort_strategy.MaterializeSortedRun(context, task_local.group_idx, hsource); + + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteGetTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + auto &asof_group = *gsource.asof_groups[task_local.group_idx]; + + const auto &gsink = gsource.op.sink_state->Cast(); + auto unused = make_uniq(); + + for (idx_t child = 0; child < gsink.sort_strategies.size(); ++child) { + // Don't get children that don't exist + if (child) { + if (!asof_group.RightChunks()) { + continue; + } + } else { + if (!asof_group.LeftChunks()) { + continue; } + } - for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { - if (hash_groups[hash_bin]) { - break; - } + auto &sort_strategy = *gsink.sort_strategies[child]; + auto &hashed_source = *gsource.hashed_sources[child]; + OperatorSourceInput hsource {hashed_source, *unused, source.interrupt_state}; + + auto group = sort_strategy.GetSortedRun(context.client, task_local.group_idx, hsource); + if (group) { + if (child) { + asof_group.right_group = std::move(group); + } else { + asof_group.left_group = std::move(group); } - lsource.BeginRightScan(hash_bin); } - const auto rhs_position = lsource.scanner->Scanned(); - lsource.scanner->Scan(rhs_chunk); + } - const auto count = rhs_chunk.size(); - if (count == 0) { - return SourceResultType::FINISHED; + // Mark this range as done + task->begin_idx = task->end_idx; +} + +void AsOfLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &source) { + while (probe_buffer.HasMoreData()) { + probe_buffer.GetData(context, chunk); + if (chunk.size()) { + return; } + } +} + +SourceResultType PhysicalAsOfJoin::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); + + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.HasUnfinishedTasks() && chunk.size() == 0) { + if (!lsource.TaskFinished() || lsource.TryAssignTask()) { + try { + lsource.ExecuteTask(context, chunk, input); + } catch (...) { + gsource.stopped = true; + throw; + } + } else { + auto guard = gsource.Lock(); + if (!gsource.HasMoreTasks()) { + gsource.UnblockTasks(guard); + } else { + // there are more tasks available, but we can't execute them yet + // block the source + return gsource.BlockSource(guard, input.interrupt_state); + } + } + } + + return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; +} + +void AsOfLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) { + while (task->begin_idx < task->end_idx) { + const auto rhs_position = scanner->Scanned(); + scanner->Scan(rhs_chunk); + ++task->begin_idx; // figure out which tuples didn't find a match in the RHS - auto found_match = lsource.found_match; + const auto count = rhs_chunk.size(); idx_t result_count = 0; for (idx_t i = 0; i < count; i++) { - if (!found_match[rhs_position + i]) { + if (!rhs_matches[rhs_position + i]) { rsel.set_index(result_count++, i); } } + if (!result_count) { + continue; + } - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0].get().GetTypes().size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { - const auto rhs_idx = right_projection_map[col_idx]; - chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); - } - chunk.SetCardinality(result_count); - break; + // if there were any tuples that didn't find a match, output them + const auto &op = gsource.op; + const idx_t left_column_count = op.children[0].get().GetTypes().size(); + for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } + for (idx_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); } + chunk.SetCardinality(result_count); + return; } - return chunk.size() > 0 ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; + // Exhausted the task data + scanner.reset(); +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalAsOfJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + D_ASSERT(children.size() == 2); + if (meta_pipeline.HasRecursiveCTE()) { + throw NotImplementedException("AsOf joins are not supported in recursive CTEs yet"); + } + + // becomes a source after both children fully sink their data + meta_pipeline.GetState().SetPipelineSource(current, *this); + + // Create one child meta pipeline that will hold the LHS and RHS pipelines + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + + // Build out RHS first because that is the order the join planner expects. 
+ auto rhs_pipeline = child_meta_pipeline.GetBasePipeline(); + children[1].get().BuildPipelines(*rhs_pipeline, child_meta_pipeline); + + // Build out LHS + auto &lhs_pipeline = child_meta_pipeline.CreatePipeline(); + children[0].get().BuildPipelines(lhs_pipeline, child_meta_pipeline); + + // Despite having the same sink, LHS and everything created after it need their own (same) PipelineFinishEvent + child_meta_pipeline.AddFinishEvent(lhs_pipeline); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp index 0b6585b7b..b6431d1ad 100644 --- a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp @@ -261,8 +261,8 @@ unique_ptr PhysicalBlockwiseNLJoin::GetLocalSourceState(Execut return make_uniq(*this, gstate.Cast()); } -SourceResultType PhysicalBlockwiseNLJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalBlockwiseNLJoin::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { D_ASSERT(PropagatesBuildSide(join_type)); // check if we need to scan any unmatched tuples from the RHS for the full/right outer join auto &sink = sink_state->Cast(); diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index 9513bded8..12ec306b8 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -20,6 +20,7 @@ #include "duckdb/planner/filter/constant_filter.hpp" #include "duckdb/planner/filter/in_filter.hpp" #include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/temporary_memory_manager.hpp" @@ -36,7 +37,6 @@ PhysicalHashJoin::PhysicalHashJoin(PhysicalPlan &physical_plan, LogicalOperator : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), delim_types(std::move(delim_types)) { - filter_pushdown = std::move(pushdown_info_p); children.push_back(left); @@ -164,6 +164,11 @@ class HashJoinGlobalSinkState : public GlobalSinkState { } } + ~HashJoinGlobalSinkState() override { + DUCKDB_LOG(context, PhysicalOperatorLogType, op, "PhysicalHashJoin", "GetData", + {{"total_probe_matches", to_string(hash_table->total_probe_matches)}}); + } + void ScheduleFinalize(Pipeline &pipeline, Event &event); void InitializeProbeSpill(); @@ -283,7 +288,7 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c auto count_fun = CountFunctionBase::GetFunction(); vector> children; // this is a dummy but we need it to make the hash table understand whats going on - children.push_back(make_uniq_base(count_fun.return_type, 0U)); + children.push_back(make_uniq_base(count_fun.GetReturnType(), 0U)); aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, AggregateType::NON_DISTINCT); correlated_aggregates.push_back(&*aggr); @@ -391,7 +396,6 @@ static bool KeysAreSkewed(const HashJoinGlobalSinkState &sink) { //! If we have only one thread, always finalize single-threaded. Otherwise, we finalize in parallel if we //! 
have more than 1M rows or if we want to verify parallelism. static bool FinalizeSingleThreaded(const HashJoinGlobalSinkState &sink, const bool consider_skew) { - // if only one thread, finalize single-threaded const auto num_threads = NumericCast(sink.num_threads); if (num_threads == 1) { @@ -701,26 +705,23 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { } }; +bool JoinFilterPushdownInfo::CanUseInFilter(const ClientContext &context, optional_ptr ht, + const ExpressionType &cmp) const { + auto dynamic_or_filter_threshold = DBConfig::GetSetting(context); + return ht && ht->Count() > 1 && ht->Count() <= dynamic_or_filter_threshold && cmp == ExpressionType::COMPARE_EQUAL; +} + void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, idx_t filter_col_idx) const { // generate a "OR" filter (i.e. x=1 OR x=535 OR x=997) // first scan the entire vector at the probe side - // FIXME: this code is duplicated from PerfectHashJoinExecutor::FullScanHashTable auto build_idx = join_condition[filter_idx]; - auto &data_collection = ht.GetDataCollection(); - Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples - - JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); - - // Go through all the blocks and fill the keys addresses - idx_t key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); - - // Scan the build keys in the hash table - Vector build_vector(ht.layout_ptr->GetTypes()[build_idx], key_count); - data_collection.Gather(tuples_addresses, *FlatVector::IncrementalSelectionVector(), key_count, build_idx, - build_vector, *FlatVector::IncrementalSelectionVector(), nullptr); + Vector build_vector(ht.layout_ptr->GetTypes()[build_idx], ht.Count()); + auto key_count = ht.ScanKeyColumn(tuples_addresses, build_vector, build_idx); + if (key_count == 0) { + return; + } // generate the OR-clause - note that we only need to consider unique values here (so we use a seT) value_set_t unique_ht_values; @@ -743,12 +744,49 @@ void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, // the IN-list is expensive to execute otherwise auto filter = make_uniq(std::move(in_filter)); info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(filter)); - return; } -unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, optional_ptr ht, - JoinFilterGlobalState &gstate, - const PhysicalComparisonJoin &op) const { +bool JoinFilterPushdownInfo::CanUseBloomFilter(const ClientContext &context, optional_ptr ht, + const PhysicalComparisonJoin &op, const ExpressionType &cmp, + const bool is_perfect_hashtable) const { + if (!ht) { + return false; + } + + // with a perfect hashtable we expect good min/max pruning, so we don't want the bloom filter + if (is_perfect_hashtable) { + return false; + } + + // bf is only supported for single key joins with equality condition as the Filter API only allows + // single-column filters so far + const bool can_use_bf = ht->conditions.size() == 1 && cmp == ExpressionType::COMPARE_EQUAL; + + // building the bloom filter is costly on the build to make probing faster, so only use it if there are + // more probing tuples than build tuples + const double build_to_probe_ratio = + static_cast(op.children[0].get().estimated_cardinality) / static_cast(ht->Count()); + const bool probe_larger_then_build = build_to_probe_ratio > 1.0; + + // only 
use bloom filter if there is no in-filter already + return can_use_bf && build_side_has_filter && probe_larger_then_build; +} + +void JoinFilterPushdownInfo::PushBloomFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, + const PhysicalOperator &op, idx_t filter_col_idx) const { + // If the nulls are equal, we let nulls pass. If not, we filter them + auto filters_null_values = !ht.NullValuesAreEqual(0); + const auto key_name = ht.conditions[0].right->ToString(); + const auto key_type = ht.conditions[0].left->return_type; + auto bf_filter = make_uniq(ht.GetBloomFilter(), filters_null_values, key_name, key_type); + ht.SetBuildBloomFilter(true); + + auto opt_bf_filter = make_uniq( + std::move(bf_filter), SelectivityOptionalFilter::BF_THRESHOLD, SelectivityOptionalFilter::BF_CHECK_N); + info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(opt_bf_filter)); +} + +unique_ptr JoinFilterPushdownInfo::FinalizeMinMax(JoinFilterGlobalState &gstate) const { // finalize the min/max aggregates vector min_max_types; for (auto &aggr_expr : min_max_aggregates) { @@ -758,12 +796,17 @@ unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, o final_min_max->Initialize(Allocator::DefaultAllocator(), min_max_types); gstate.global_aggregate_state->Finalize(*final_min_max); + return final_min_max; +} +unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, optional_ptr ht, + const PhysicalComparisonJoin &op, + unique_ptr final_min_max, + const bool is_perfect_hashtable) const { if (probe_info.empty()) { return final_min_max; // There are not table souces in which we can push down filters } - auto dynamic_or_filter_threshold = DBConfig::GetSetting(context); // create a filter for each of the aggregates for (idx_t filter_idx = 0; filter_idx < join_condition.size(); filter_idx++) { const auto cmp = op.conditions[join_condition[filter_idx]].comparison; @@ -782,11 +825,9 @@ unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, o } // if the HT is small we can generate a complete "OR" filter // but only if the join condition is equality. 
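The probe-side filters added above are pushed as SelectivityOptionalFilter wrappers carrying a threshold and a check count, but the wrapper's implementation is not part of this diff. The following standalone sketch only illustrates the general idea, under the assumption that such a wrapper measures its own selectivity over the first rows it sees and disables itself when it keeps too much of its input; every name in it (SampledFilter, keep_threshold, check_n) is made up for illustration.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical sketch (not DuckDB's SelectivityOptionalFilter): a predicate wrapper that
// tracks how many rows it keeps and switches itself off when it is not selective enough.
class SampledFilter {
public:
	SampledFilter(std::function<bool(int)> pred, double keep_threshold, std::size_t check_n)
	    : pred(std::move(pred)), keep_threshold(keep_threshold), check_n(check_n) {
	}

	std::vector<int> Apply(const std::vector<int> &rows) {
		if (disabled) {
			return rows; // no longer worth evaluating
		}
		std::vector<int> result;
		for (auto v : rows) {
			++checked;
			if (pred(v)) {
				++kept;
				result.push_back(v);
			}
		}
		// After check_n rows, stop filtering if the filter keeps more than keep_threshold of its input.
		if (checked >= check_n && double(kept) / double(checked) > keep_threshold) {
			disabled = true;
		}
		return result;
	}

private:
	std::function<bool(int)> pred;
	double keep_threshold;
	std::size_t check_n;
	std::size_t checked = 0;
	std::size_t kept = 0;
	bool disabled = false;
};

int main() {
	SampledFilter filter([](int v) { return v >= 40 && v <= 60; }, 0.5, 1024);
	std::vector<int> batch {10, 45, 80, 55, 90};
	std::printf("kept %zu rows\n", filter.Apply(batch).size());
	return 0;
}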
- if (ht && ht->Count() > 1 && ht->Count() <= dynamic_or_filter_threshold && - cmp == ExpressionType::COMPARE_EQUAL) { + if (ht && CanUseInFilter(context, ht, cmp)) { PushInFilter(info, *ht, op, filter_idx, filter_col_idx); } - if (Value::NotDistinctFrom(min_val, max_val)) { // min = max - single value // generate a "one-sided" comparison filter for the LHS @@ -794,16 +835,20 @@ unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, o auto constant_filter = make_uniq(cmp, std::move(min_val)); info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(constant_filter)); } else { - // min != max - generate a range filter + // min != max - generate a range filter or bloom filter + optional range filter // for non-equalities, the range must be half-open // e.g., for lhs < rhs we can only use lhs <= max + switch (cmp) { case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_GREATERTHAN: case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { auto greater_equals = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, std::move(min_val)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(greater_equals)); + auto optional_greater_equals = make_uniq( + std::move(greater_equals), SelectivityOptionalFilter::MIN_MAX_THRESHOLD, + SelectivityOptionalFilter::MIN_MAX_CHECK_N); + info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(optional_greater_equals)); break; } default: @@ -815,19 +860,32 @@ unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, o case ExpressionType::COMPARE_LESSTHANOREQUALTO: { auto less_equals = make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, std::move(max_val)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(less_equals)); + auto optional_less_equals = make_uniq( + std::move(less_equals), SelectivityOptionalFilter::MIN_MAX_THRESHOLD, + SelectivityOptionalFilter::MIN_MAX_CHECK_N); + info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(optional_less_equals)); break; } default: break; } + + if (ht && CanUseBloomFilter(context, ht, op, cmp, is_perfect_hashtable)) { + PushBloomFilter(info, *ht, op, filter_col_idx); + } } } } - return final_min_max; } +unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, optional_ptr ht, + JoinFilterGlobalState &gstate, + const PhysicalComparisonJoin &op) const { + auto final_min_max = FinalizeMinMax(gstate); + return FinalizeFilters(context, ht, op, std::move(final_min_max), false); +} + SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, OperatorSinkFinalizeInput &input) const { auto &sink = input.global_state.Cast(); @@ -900,10 +958,12 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl Value min; Value max; + unique_ptr filter_min_max = nullptr; + if (filter_pushdown && !sink.skip_filter_pushdown && ht.Count() > 0) { - auto final_min_max = filter_pushdown->Finalize(context, &ht, *sink.global_filter_state, *this); - min = final_min_max->data[0].GetValue(0); - max = final_min_max->data[1].GetValue(0); + filter_min_max = filter_pushdown->FinalizeMinMax(*sink.global_filter_state); + min = filter_min_max->data[0].GetValue(0); + max = filter_min_max->data[1].GetValue(0); } else if (TypeIsIntegral(conditions[0].right->return_type.InternalType())) { min = Value::MinimumValue(conditions[0].right->return_type); max = Value::MaximumValue(conditions[0].right->return_type); @@ -916,6 +976,11 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, 
Cl auto key_type = ht.equality_types[0]; use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(key_type); } + + if (filter_min_max) { + filter_pushdown->FinalizeFilters(context, &ht, *this, std::move(filter_min_max), use_perfect_hash); + } + // In case of a large build side or duplicates, use regular hash join if (!use_perfect_hash) { sink.perfect_join_executor.reset(); @@ -1159,7 +1224,8 @@ unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionCont HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context) : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), probe_chunk_done(0), probe_count(op.children[0].get().estimated_cardinality), - parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { + parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120), full_outer_chunk_count(0), + full_outer_chunk_done(0) { } void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { @@ -1439,8 +1505,8 @@ void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, Has } } -SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalHashJoin::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &sink = sink_state->Cast(); auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp index 90aba4722..bfd37e09b 100644 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -1,15 +1,8 @@ -#include - #include "duckdb/execution/operator/join/physical_iejoin.hpp" #include "duckdb/common/atomic.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/thread.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" #include "duckdb/parallel/meta_pipeline.hpp" @@ -17,6 +10,8 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include + namespace duckdb { PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, @@ -24,7 +19,6 @@ PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoi idx_t estimated_cardinality, unique_ptr pushdown_info) : PhysicalRangeJoin(physical_plan, op, PhysicalOperatorType::IE_JOIN, left, right, std::move(cond), join_type, estimated_cardinality, std::move(pushdown_info)) { - // 1. let L1 (resp. L2) be the array of column X (resp. 
Y) D_ASSERT(conditions.size() >= 2); for (idx_t i = 0; i < 2; ++i) { @@ -82,17 +76,15 @@ class IEJoinGlobalState : public GlobalSinkState { public: IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(1) { tables.resize(2); - RowLayout lhs_layout; - lhs_layout.Initialize(op.children[0].get().GetTypes()); + const auto &lhs_types = op.children[0].get().GetTypes(); vector lhs_order; lhs_order.emplace_back(op.lhs_orders[0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout, op); + tables[0] = make_uniq(context, lhs_order, lhs_types, op); - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout, op); + tables[1] = make_uniq(context, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); @@ -100,11 +92,18 @@ class IEJoinGlobalState : public GlobalSinkState { } } - void Sink(DataChunk &input, IEJoinLocalState &lstate); - void Finalize(Pipeline &pipeline, Event &event) { + void Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate); + + void Finalize(ClientContext &client, InterruptState &interrupt) { + // Sort the current input child + D_ASSERT(child < tables.size()); + tables[child]->Finalize(client, interrupt); + }; + + void Materialize(Pipeline &pipeline, Event &event) { // Sort the current input child D_ASSERT(child < tables.size()); - tables[child]->Finalize(pipeline, event); + tables[child]->Materialize(pipeline, event); child = child ? 0 : 2; skip_filter_pushdown = true; }; @@ -123,9 +122,8 @@ class IEJoinLocalState : public LocalSinkState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) - : table(context, op, gstate.child) { - + IEJoinLocalState(ExecutionContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) + : table(context, *gstate.tables[gstate.child], gstate.child) { if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); } @@ -144,32 +142,23 @@ unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &co unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { auto &ie_sink = sink_state->Cast(); - return make_uniq(context.client, *this, ie_sink); + return make_uniq(context, *this, ie_sink); } -void IEJoinGlobalState::Sink(DataChunk &input, IEJoinLocalState &lstate) { - auto &table = *tables[child]; - auto &global_sort_state = table.global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void IEJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - if (gstate.child == 0 && gstate.tables[1]->global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 0 
&& gstate.tables[1]->Count() == 0 && EmptyResultIfRHSIsEmpty()) { return SinkResultType::FINISHED; } - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -181,7 +170,7 @@ SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.tables[gstate.child]->Combine(lstate.table); + gstate.tables[gstate.child]->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -197,14 +186,13 @@ SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, Operato //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } auto &table = *gstate.tables[gstate.child]; - auto &global_sort_state = table.global_sort_state; if ((gstate.child == 1 && PropagatesBuildSide(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for every tuple @@ -212,15 +200,18 @@ SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, Clie } SinkFinalizeType res; - if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 1 && table.Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! res = SinkFinalizeType::NO_OUTPUT_POSSIBLE; } else { res = SinkFinalizeType::READY; } + // Clean up the current table + gstate.Finalize(client, input.interrupt_state); + // Move to the next input child - gstate.Finalize(pipeline, event); + gstate.Materialize(pipeline, event); return res; } @@ -236,21 +227,156 @@ OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, Da //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// +enum class IEJoinSourceStage : uint8_t { INIT, SORT_L1, MATERIALIZE_L1, SORT_L2, MATERIALIZE_L2, INNER, OUTER, DONE }; + +struct IEJoinSourceTask { + using ChunkRange = std::pair; + + IEJoinSourceTask() { + } + + IEJoinSourceStage stage = IEJoinSourceStage::DONE; + //! The thread index (for local state) + idx_t thread_idx = 0; + //! The chunk range + ChunkRange l_range; + //! 
The right chunk range + ChunkRange r_range; +}; + +class IEJoinLocalSourceState; + +class IEJoinGlobalSourceState : public GlobalSourceState { +public: + using Task = IEJoinSourceTask; + using TaskPtr = optional_ptr; + using SortedTable = PhysicalRangeJoin::GlobalSortedTable; + + IEJoinGlobalSourceState(const PhysicalIEJoin &op, ClientContext &client, IEJoinGlobalState &gsink); + + template + static T BinValue(T n, T val) { + return ((n + (val - 1)) / val); + } + + idx_t GetStageCount(IEJoinSourceStage stage) const { + return stage_tasks[size_t(stage)]; + } + + atomic &GetStageNext(IEJoinSourceStage stage) { + return completed[size_t(stage)]; + } + + //! The processing stage for this group + IEJoinSourceStage GetStage() const { + return stage; + } + + //! The total number of tasks we will execute + idx_t GetTaskCount() const { + return stage_begin[size_t(IEJoinSourceStage::DONE)]; + } + + //! Are there any more tasks? + bool HasMoreTasks() const { + return !stopped && started < total_tasks; + } + bool HasUnfinishedTasks() const { + return !stopped && finished < total_tasks; + } + + bool TryPrepareNextStage(); + bool TryNextTask(TaskPtr &task, Task &task_local); + + void FinishL1Task(); + void FinishL2Task(); + +public: + idx_t MaxThreads() override; + + ProgressData GetProgress() const; + + const PhysicalIEJoin &op; + IEJoinGlobalState &gsink; + + //! The processing stage + IEJoinSourceStage stage = IEJoinSourceStage::INIT; + //! The number of tasks per stage. + vector stage_tasks; + //! The first task in the stage. + vector stage_begin; + //! The next task to process + idx_t next_task = 0; + //! The total number of tasks + idx_t total_tasks = 0; + //! The number of started tasks + atomic started; + //! The number of tasks finished. + atomic finished; + //! Stop producing tasks + atomic stopped; + //! The number of completed tasks for each stage + array, size_t(IEJoinSourceStage::DONE)> completed; + + //! L1 + unique_ptr l1; + //! L2 + unique_ptr l2; + //! Li + vector li; + //!
P + vector p; + + // Join queue state + idx_t join_blocks = 0; + idx_t per_thread = 0; + idx_t left_blocks = 0; + idx_t right_blocks = 0; + + // Outer joins + idx_t left_outers = 0; + idx_t right_outers = 0; + +protected: + void Initialize(); + void FinishTask(TaskPtr task); + bool TryNextTask(Task &task); +}; + struct IEJoinUnion { using SortedTable = PhysicalRangeJoin::GlobalSortedTable; + using ChunkRange = std::pair; + + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw InternalException("Unimplemented comparison type for IEJoin!"); + } + } - static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx); - - static void Sort(SortedTable &table) { - auto &global_sort_state = table.global_sort_state; - global_sort_state.PrepareMergePhase(); - while (global_sort_state.sorted_blocks.size() > 1) { - global_sort_state.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort_state.CompleteMergeRound(true); + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); } + return less_than; + } + + static idx_t AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &range); + + static void Sort(ExecutionContext &context, InterruptState &interrupt, SortedTable &table) { + table.Finalize(context.client, interrupt); + table.Materialize(context, interrupt); } template @@ -258,21 +384,17 @@ struct IEJoinUnion { vector result; result.reserve(table.count); - auto &gstate = table.global_sort_state; - auto &blocks = *gstate.sorted_blocks[0]->payload_data; - PayloadScanner scanner(blocks, gstate, false); + auto &collection = *table.sorted->payload_data; + vector scan_ids(1, col_idx); + TupleDataScanState state; + collection.InitializeScan(state, scan_ids); DataChunk payload; - payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); - for (;;) { - payload.Reset(); - scanner.Scan(payload); - const auto count = payload.size(); - if (!count) { - break; - } + collection.InitializeScanChunk(state, payload); - const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); + while (collection.Scan(state, payload)) { + const auto count = payload.size(); + const auto data_ptr = FlatVector::GetData(payload.data[0]); for (idx_t i = 0; i < count; i++) { result.push_back(UnsafeNumericCast(data_ptr[i])); } @@ -281,29 +403,48 @@ struct IEJoinUnion { return result; } - IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, - const idx_t b2); + class UnionIterator { + public: + UnionIterator(SortedTable &table, bool strict) : state(table.CreateIteratorState()), strict(strict) { + } + + inline idx_t GetIndex() const { + return index; + } + + inline void SetIndex(idx_t i) { + index = i; + } + + UnionIterator &operator++() { + ++index; + return *this; + } + + unique_ptr state; + idx_t index = 0; + const bool 
strict; + }; + + IEJoinUnion(IEJoinGlobalSourceState &gsource, const ChunkRange &chunks); idx_t SearchL1(idx_t pos); + + template bool NextRow(); - //! Inverted loop - idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); + using next_row_t = bool (duckdb::IEJoinUnion::*)(); + next_row_t next_row_func; - //! L1 - unique_ptr l1; - //! L2 - unique_ptr l2; + //! Constructor arguments + IEJoinGlobalSourceState &gsource; - //! Li - vector li; - //! P - vector p; + //! Inverted loop + idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); //! B vector bit_array; ValidityMask bit_mask; - //! Bloom Filter static constexpr idx_t BLOOM_CHUNK_BITS = 1024; idx_t bloom_count; @@ -313,50 +454,64 @@ struct IEJoinUnion { //! Iteration state idx_t n; idx_t i; + idx_t n_j; idx_t j; - unique_ptr op1; - unique_ptr off1; - unique_ptr op2; - unique_ptr off2; + unique_ptr op1; + unique_ptr off1; + unique_ptr op2; + unique_ptr off2; int64_t lrid; }; -idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx) { - LocalSortState local_sort_state; - local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); +idx_t IEJoinUnion::AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &chunk_range) { + const auto chunk_begin = chunk_range.first; + const auto chunk_end = chunk_range.second; + + if (chunk_begin == chunk_end) { + return 0; + } // Reading const auto valid = table.count - table.has_null; - auto &gstate = table.global_sort_state; - PayloadScanner scanner(gstate, block_idx); - auto table_idx = block_idx * gstate.block_capacity; + auto &source = *table.sorted->payload_data; + TupleDataScanState scanner; + source.InitializeScan(scanner); DataChunk scanned; - scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); + source.InitializeScanChunk(scanner, scanned); + idx_t table_idx = source.Seek(scanner, chunk_begin); // Writing - auto types = local_sort_state.sort_layout->logical_types; - const idx_t payload_idx = types.size(); - - const auto &payload_types = local_sort_state.payload_layout->GetTypes(); - types.insert(types.end(), payload_types.begin(), payload_types.end()); - const idx_t rid_idx = types.size() - 1; + auto &sort = *marked.sort; + auto local_sort_state = sort.GetLocalSinkState(context); + vector types; + for (const auto &expr : executor.expressions) { + types.emplace_back(expr->return_type); + } + const idx_t rid_idx = types.size(); + types.emplace_back(LogicalType::BIGINT); DataChunk keys; DataChunk payload; keys.Initialize(Allocator::DefaultAllocator(), types); + OperatorSinkInput sink {*marked.global_sink, *local_sort_state, interrupt}; idx_t inserted = 0; - for (auto rid = base; table_idx < valid;) { - scanned.Reset(); - scanner.Scan(scanned); + for (auto chunk_idx = chunk_begin; chunk_idx < chunk_end; ++chunk_idx) { + source.Scan(scanner, scanned); // NULLs are at the end, so stop when we reach them auto scan_count = scanned.size(); if (table_idx + scan_count > valid) { - scan_count = valid - table_idx; - scanned.SetCardinality(scan_count); + if (table_idx >= valid) { + scan_count = 0; + ; + } else { + scan_count = valid - table_idx; + scanned.SetCardinality(scan_count); + } } if (scan_count == 0) { break; @@ -375,166 +530,90 @@ idx_t IEJoinUnion::AppendKey(SortedTable &table, 
ExpressionExecutor &executor, S rid += increment * UnsafeNumericCast(scan_count); // Sort on the sort columns (which will no longer be needed) - keys.Split(payload, payload_idx); - local_sort_state.SinkChunk(keys, payload); + sort.Sink(context, keys, sink); inserted += scan_count; - keys.Fuse(payload); - - // Flush when we have enough data - if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { - local_sort_state.Sort(marked.global_sort_state, true); - } } - marked.global_sort_state.AddLocalState(local_sort_state); + OperatorSinkCombineInput combine {*marked.global_sink, *local_sort_state, interrupt}; + sort.Combine(context, combine); marked.count += inserted; return inserted; } -IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, - SortedTable &t2, const idx_t b2) - : n(0), i(0) { - // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. - // output: a list of tuple pairs (ti , tj) - // Note that T/T' are already sorted on X/X' and contain the payload data - // We only join the two block numbers and use the sizes of the blocks as the counts - - // 0. Filter out tables with no overlap - if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { - return; - } - - const auto &cmp1 = op.conditions[0].comparison; - SBIterator bounds1(t1.global_sort_state, cmp1); - SBIterator bounds2(t2.global_sort_state, cmp1); - - // t1.X[0] op1 t2.X'[-1] - bounds1.SetIndex(bounds1.block_capacity * b1); - bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); - if (!bounds1.Compare(bounds2)) { - return; - } - - // 1. let L1 (resp. L2) be the array of column X (resp. Y ) - const auto &order1 = op.lhs_orders[0]; - const auto &order2 = op.lhs_orders[1]; - - // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order - // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order - - // For the union algorithm, we make a unified table with the keys and the rids as the payload: - // X/X', Y/Y', R/R'/Li - // The first position is the sort key. - vector types; - types.emplace_back(order2.expression->return_type); - types.emplace_back(LogicalType::BIGINT); - RowLayout payload_layout; - payload_layout.Initialize(types); - - // Sort on the first expression - auto ref = make_uniq(order1.expression->return_type, 0U); - vector orders; - orders.emplace_back(order1.type, order1.null_order, std::move(ref)); - // The goal is to make i (from the left table) < j (from the right table), - // if value[i] and value[j] match the condition 1. - // Add a column from_left to solve the problem when there exist multiple equal values in l1. - // If the operator is loose inequality, make t1.from_left (== true) sort BEFORE t2.from_left (== false). - // Otherwise, make t1.from_left sort (== true) sort AFTER t2.from_left (== false). - // For example, if t1.time <= t2.time - // | value | 1 | 1 | 1 | 1 | - // | --------- | ----- | ----- | ----- | ----- | - // | from_left | T(l2) | T(l2) | F(r1) | F(r2) | - // if t1.time < t2.time - // | value | 1 | 1 | 1 | 1 | - // | --------- | ----- | ----- | ----- | ----- | - // | from_left | F(r2) | F(r1) | T(l2) | T(l1) | - // Using this OrderType, if i < j then value[i] (from left table) and value[j] (from right table) match - // the condition (t1.time <= t2.time or t1.time < t2.time), then from_left will force them into the correct order. - auto from_left = make_uniq(Value::BOOLEAN(true)); - orders.emplace_back(SBIterator::ComparisonValue(cmp1) == 0 ? 
OrderType::DESCENDING : OrderType::ASCENDING, - OrderByNullType::ORDER_DEFAULT, std::move(from_left)); - - l1 = make_uniq(context, orders, payload_layout, op); - - // LHS has positive rids - ExpressionExecutor l_executor(context); - l_executor.AddExpression(*order1.expression); - // add const column true - auto left_const = make_uniq(Value::BOOLEAN(true)); - l_executor.AddExpression(*left_const); - l_executor.AddExpression(*order2.expression); - AppendKey(t1, l_executor, *l1, 1, 1, b1); - - // RHS has negative rids - ExpressionExecutor r_executor(context); - r_executor.AddExpression(*op.rhs_orders[0].expression); - // add const column flase - auto right_const = make_uniq(Value::BOOLEAN(false)); - r_executor.AddExpression(*right_const); - r_executor.AddExpression(*op.rhs_orders[1].expression); - AppendKey(t2, r_executor, *l1, -1, -1, b2); - - if (l1->global_sort_state.sorted_blocks.empty()) { - return; - } - - Sort(*l1); - - op1 = make_uniq(l1->global_sort_state, cmp1); - off1 = make_uniq(l1->global_sort_state, cmp1); - - // We don't actually need the L1 column, just its sort key, which is in the sort blocks - li = ExtractColumn(*l1, types.size() - 1); - - // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order - // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order - - // We sort on Y/Y' to obtain the sort keys and the permutation array. - // For this we just need a two-column table of Y, P - types.clear(); - types.emplace_back(LogicalType::BIGINT); - payload_layout.Initialize(types); - - // Sort on the first expression - orders.clear(); - ref = make_uniq(order2.expression->return_type, 0U); - orders.emplace_back(order2.type, order2.null_order, std::move(ref)); - - ExpressionExecutor executor(context); - executor.AddExpression(*orders[0].expression); - - l2 = make_uniq(context, orders, payload_layout, op); - for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); - } - - Sort(*l2); - - // We don't actually need the L2 column, just its sort key, which is in the sort blocks - - // 6. compute the permutation array P of L2 w.r.t. L1 - p = ExtractColumn(*l2, types.size() - 1); +IEJoinUnion::IEJoinUnion(IEJoinGlobalSourceState &gsource, const ChunkRange &chunks) : gsource(gsource), n(0), i(0) { + auto &op = gsource.op; + auto &l1 = *gsource.l1; + const auto strict1 = IsStrictComparison(op.conditions[0].comparison); + op1 = make_uniq(l1, strict1); + off1 = make_uniq(l1, strict1); // 7. initialize bit-array B (|B| = n), and set all bits to 0 - n = l2->count.load(); - bit_array.resize(ValidityMask::EntryCount(n), 0); - bit_mask.Initialize(bit_array.data(), n); + auto &l2 = *gsource.l2; + n_j = l2.count.load(); + bit_array.resize(ValidityMask::EntryCount(n_j), 0); + bit_mask.Initialize(bit_array.data(), n_j); // Bloom filter - bloom_count = (n + (BLOOM_CHUNK_BITS - 1)) / BLOOM_CHUNK_BITS; + bloom_count = (n_j + (BLOOM_CHUNK_BITS - 1)) / BLOOM_CHUNK_BITS; bloom_array.resize(ValidityMask::EntryCount(bloom_count), 0); bloom_filter.Initialize(bloom_array.data(), bloom_count); // 11. 
for(i←1 to n) do - const auto &cmp2 = op.conditions[1].comparison; - op2 = make_uniq(l2->global_sort_state, cmp2); - off2 = make_uniq(l2->global_sort_state, cmp2); - i = 0; - j = 0; - (void)NextRow(); + const auto strict2 = IsStrictComparison(op.conditions[1].comparison); + op2 = make_uniq(l2, strict2); + off2 = make_uniq(l2, strict2); + n = l2.BlockStart(chunks.second); + i = l2.BlockStart(chunks.first); + j = i; + + const auto sort_key_type = l2.GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + default: + throw NotImplementedException("IEJoinUnion for %s", EnumUtil::ToString(sort_key_type)); + } + + (this->*next_row_func)(); } +template bool IEJoinUnion::NextRow() { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + + BLOCKS_ITERATOR off2_itr(*off2->state); + BLOCKS_ITERATOR op2_itr(*op2->state); + const auto strict = off2->strict; + + auto &li = gsource.li; + auto &p = gsource.p; + for (; i < n; ++i) { // 12. pos ← P[i] auto pos = p[i]; @@ -545,8 +624,8 @@ bool IEJoinUnion::NextRow() { // 16. B[pos] ← 1 op2->SetIndex(i); - for (; off2->GetIndex() < n; ++(*off2)) { - if (!off2->Compare(*op2)) { + for (; off2->GetIndex() < n_j; ++(*off2)) { + if (!Compare(off2_itr[off2->GetIndex()], op2_itr[op2->GetIndex()], strict)) { break; } const auto p2 = p[off2->GetIndex()]; @@ -611,6 +690,8 @@ static idx_t NextValid(const ValidityMask &bits, idx_t j, const idx_t n) { } idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel) { + auto &li = gsource.li; + // 8. initialize join result as an empty list for tuple pairs idx_t result_count = 0; @@ -621,9 +702,9 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse // 14. 
if B[j] = 1 then // Use the Bloom filter to find candidate blocks - while (j < n) { + while (j < n_j) { auto bloom_begin = NextValid(bloom_filter, j / BLOOM_CHUNK_BITS, bloom_count) * BLOOM_CHUNK_BITS; - auto bloom_end = MinValue(n, bloom_begin + BLOOM_CHUNK_BITS); + auto bloom_end = MinValue(n_j, bloom_begin + BLOOM_CHUNK_BITS); j = MaxValue(j, bloom_begin); j = NextValid(bit_mask, j, bloom_end); @@ -632,7 +713,7 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse } } - if (j >= n) { + if (j >= n_j) { break; } @@ -652,7 +733,7 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse } ++i; - if (!NextRow()) { + if (!(this->*next_row_func)()) { break; } } @@ -660,13 +741,134 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse return result_count; } +IEJoinGlobalSourceState::IEJoinGlobalSourceState(const PhysicalIEJoin &op, ClientContext &client, + IEJoinGlobalState &gsink) + : op(op), gsink(gsink), stage(IEJoinSourceStage::INIT), started(0), finished(0), stopped(false) { + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + left_blocks = left_table.BlockCount(); + if (left_table.found_match) { + left_outers = left_blocks; + } + + right_blocks = right_table.BlockCount(); + if (right_table.found_match) { + right_outers = right_blocks; + } + + // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. + // output: a list of tuple pairs (ti , tj) + // Note that T/T' are already sorted on X/X' and contain the payload data + // We only join the two block numbers and use the sizes of the blocks as the counts + + // 1. let L1 (resp. L2) be the array of column X (resp. Y ) + const auto &order1 = op.lhs_orders[0]; + const auto &order2 = op.lhs_orders[1]; + + // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order + // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order + + // For the union algorithm, we make a unified table with the keys and the rids as the payload: + // X/X', Y/Y', R/R'/Li + // The first position is the sort key. + vector types; + types.emplace_back(order2.expression->return_type); + types.emplace_back(LogicalType::BIGINT); + + // Sort on the first expression + auto ref = make_uniq(order1.expression->return_type, 0U); + vector orders; + orders.emplace_back(order1.type, order1.null_order, std::move(ref)); + // The goal is to make i (from the left table) < j (from the right table), + // if value[i] and value[j] match the condition 1. + // Add a column from_left to solve the problem when there exist multiple equal values in l1. + // If the operator is loose inequality, make t1.from_left (== true) sort BEFORE t2.from_left (== false). + // Otherwise, make t1.from_left sort (== true) sort AFTER t2.from_left (== false). + // For example, if t1.time <= t2.time + // | value | 1 | 1 | 1 | 1 | + // | --------- | ----- | ----- | ----- | ----- | + // | from_left | T(l2) | T(l2) | F(r1) | F(r2) | + // if t1.time < t2.time + // | value | 1 | 1 | 1 | 1 | + // | --------- | ----- | ----- | ----- | ----- | + // | from_left | F(r2) | F(r1) | T(l2) | T(l1) | + // Using this OrderType, if i < j then value[i] (from left table) and value[j] (from right table) match + // the condition (t1.time <= t2.time or t1.time < t2.time), then from_left will force them into the correct order. 
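The comment above explains how the extra from_left column orders rows with equal keys: under a loose comparison (<=), build-side rows must land before probe-side rows so that "i < j" still implies a match, while under a strict comparison (<) the order is reversed so equal keys never match. A small self-contained illustration of the two tie-break orders (not the operator's actual sort code):

#include <algorithm>
#include <cstdio>
#include <vector>

struct Row {
	int value;
	bool from_left; // true = LHS row, false = RHS row
};

int main() {
	std::vector<Row> rows {{1, true}, {1, false}, {1, true}, {1, false}};

	// Loose inequality (t1.time <= t2.time): at equal keys, LHS rows sort BEFORE RHS rows,
	// i.e. from_left == true first (descending on the flag).
	auto loose = rows;
	std::stable_sort(loose.begin(), loose.end(), [](const Row &a, const Row &b) {
		return a.value < b.value || (a.value == b.value && a.from_left > b.from_left);
	});

	// Strict inequality (t1.time < t2.time): at equal keys, RHS rows sort first,
	// so equal keys can never produce a left index smaller than a right index.
	auto strict = rows;
	std::stable_sort(strict.begin(), strict.end(), [](const Row &a, const Row &b) {
		return a.value < b.value || (a.value == b.value && a.from_left < b.from_left);
	});

	for (auto &r : loose) {
		std::printf("loose:  %d %c\n", r.value, r.from_left ? 'L' : 'R');
	}
	for (auto &r : strict) {
		std::printf("strict: %d %c\n", r.value, r.from_left ? 'L' : 'R');
	}
	return 0;
}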
+ auto from_left = make_uniq(Value::BOOLEAN(true)); + const auto strict1 = IEJoinUnion::IsStrictComparison(op.conditions[0].comparison); + orders.emplace_back(!strict1 ? OrderType::DESCENDING : OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, + std::move(from_left)); + + l1 = make_uniq(client, orders, types, op); + + // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order + // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order + + // We sort on Y/Y' to obtain the sort keys and the permutation array. + // For this we just need a two-column table of Y, P + types.clear(); + types.emplace_back(LogicalType::BIGINT); + + // Sort on the first expression + orders.clear(); + ref = make_uniq(order2.expression->return_type, 0U); + orders.emplace_back(order2.type, order2.null_order, std::move(ref)); + + l2 = make_uniq(client, orders, types, op); + + // The number of blocks in L2 is not quite the sum of the blocks in the two tables... + const auto join_count = left_table.count.load() + right_table.count.load(); + join_blocks = BinValue(join_count, STANDARD_VECTOR_SIZE); + + // Schedule the largest group on as many threads as possible + auto &ts = TaskScheduler::GetScheduler(client); + const auto threads = NumericCast(ts.NumberOfThreads()); + per_thread = BinValue(join_blocks, threads); + + Initialize(); +} + +void IEJoinGlobalSourceState::FinishL1Task() { + // We don't actually need the L1 column, just its sort key, which is in the sort blocks + li = IEJoinUnion::ExtractColumn(*l1, 1); +} + +void IEJoinGlobalSourceState::FinishL2Task() { + // We don't actually need the L2 column, just its sort key, which is in the sort blocks + + // 6. compute the permutation array P of L2 w.r.t. L1 + p = IEJoinUnion::ExtractColumn(*l2, 0); +} + class IEJoinLocalSourceState : public LocalSourceState { public: - explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op) - : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context), - left_matches(nullptr), right_matches(nullptr) { - auto &allocator = Allocator::Get(context); - unprojected.Initialize(allocator, op.unprojected_types); + using Task = IEJoinSourceTask; + using TaskPtr = optional_ptr; + + IEJoinLocalSourceState(ClientContext &client, IEJoinGlobalSourceState &gsource) + : gsource(gsource), lsel(STANDARD_VECTOR_SIZE), rsel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), + left_executor(client), right_executor(client), left_matches(nullptr), right_matches(nullptr) + + { + auto &op = gsource.op; + auto &allocator = Allocator::Get(client); + unprojected.InitializeEmpty(op.unprojected_types); + lpayload.Initialize(allocator, op.children[0].get().GetTypes()); + rpayload.Initialize(allocator, op.children[1].get().GetTypes()); + + auto &ie_sink = op.sink_state->Cast(); + auto &left_table = *ie_sink.tables[0]; + auto &right_table = *ie_sink.tables[1]; + + left_iterator = left_table.CreateIteratorState(); + right_iterator = right_table.CreateIteratorState(); + + left_table.InitializePayloadState(left_chunk_state); + right_table.InitializePayloadState(right_chunk_state); + + left_scan_state = left_table.CreateScanState(client); + right_scan_state = right_table.CreateScanState(client); if (op.conditions.size() < 3) { return; @@ -703,16 +905,54 @@ class IEJoinLocalSourceState : public LocalSourceState { return count; } - const PhysicalIEJoin &op; + // Are we executing a task? 
+ bool TaskFinished() const { + return !joiner && !left_matches && !right_matches; + } + + bool TryAssignTask(); + // Sort L1 + void ExecuteSortL1Task(ExecutionContext &context, InterruptState &interrupt); + // Materialize L1 + void ExecuteMaterializeL1Task(ExecutionContext &context, InterruptState &interrupt); + // Sort L2 + void ExecuteSortL2Task(ExecutionContext &context, InterruptState &interrupt); + // Materialize L2 + void ExecuteMaterializeL2Task(ExecutionContext &context, InterruptState &interrupt); + // resolve joins that can potentially output N*M elements (INNER, LEFT, RIGHT, FULL) + void ResolveComplexJoin(ExecutionContext &context, DataChunk &result); + // Resolve left join results + void ExecuteLeftTask(ExecutionContext &context, DataChunk &result); + // Resolve right join results + void ExecuteRightTask(ExecutionContext &context, DataChunk &result); + // Execute the current task + void ExecuteTask(ExecutionContext &context, DataChunk &result, InterruptState &interrupt); + + IEJoinGlobalSourceState &gsource; + + //! The task this thread is working on + TaskPtr task; + //! The task storage + Task task_local; // Joining unique_ptr joiner; idx_t left_base; idx_t left_block_index; + unique_ptr left_iterator; + TupleDataChunkState left_chunk_state; + SelectionVector lsel; + DataChunk lpayload; + unique_ptr left_scan_state; idx_t right_base; idx_t right_block_index; + unique_ptr right_iterator; + TupleDataChunkState right_chunk_state; + SelectionVector rsel; + DataChunk rpayload; + unique_ptr right_scan_state; // Trailing predicates SelectionVector true_sel; @@ -732,262 +972,499 @@ class IEJoinLocalSourceState : public LocalSourceState { bool *right_matches; }; -void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { - auto &state = state_p.Cast(); - auto &ie_sink = sink_state->Cast(); +bool IEJoinLocalSourceState::TryAssignTask() { + // Because downstream operators may be using our internal buffers, + // we can't "finish" a task until we are about to get the next one. 
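TryAssignTask below only counts the previous task as finished at the moment the next one is requested, because chunks already handed downstream may still reference this thread's buffers. A minimal sketch of that lifecycle with hypothetical names (Task, TaskQueue, TryAssign), not the operator's code:

#include <cstddef>
#include <cstdio>
#include <mutex>
#include <optional>
#include <vector>

struct Task {
	int id;
};

class TaskQueue {
public:
	explicit TaskQueue(int count) {
		for (int i = 0; i < count; ++i) {
			tasks.push_back(Task {i});
		}
	}

	// Release the previous task, then try to claim the next one.
	bool TryAssign(std::optional<Task> &current) {
		std::lock_guard<std::mutex> guard(lock);
		if (current) {
			++finished; // safe now: the caller no longer hands out buffers from the old task
			current.reset();
		}
		if (next >= tasks.size()) {
			return false;
		}
		current = tasks[next++];
		return true;
	}

	std::size_t Finished() const {
		return finished;
	}

private:
	std::mutex lock;
	std::vector<Task> tasks;
	std::size_t next = 0;
	std::size_t finished = 0;
};

int main() {
	TaskQueue queue(3);
	std::optional<Task> current;
	while (queue.TryAssign(current)) {
		// ... execute *current and hand its output downstream ...
	}
	std::printf("finished %zu tasks\n", queue.Finished());
	return 0;
}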
+ if (task) { + switch (task->stage) { + case IEJoinSourceStage::SORT_L1: + ++gsource.GetStageNext(task->stage); + break; + case IEJoinSourceStage::MATERIALIZE_L1: + gsource.FinishL1Task(); + ++gsource.GetStageNext(task->stage); + break; + case IEJoinSourceStage::SORT_L2: + ++gsource.GetStageNext(task->stage); + break; + case IEJoinSourceStage::MATERIALIZE_L2: + gsource.FinishL2Task(); + ++gsource.GetStageNext(task->stage); + break; + case IEJoinSourceStage::INNER: + ++gsource.GetStageNext(task->stage); + break; + case IEJoinSourceStage::OUTER: + ++gsource.GetStageNext(task->stage); + left_matches = nullptr; + right_matches = nullptr; + break; + case IEJoinSourceStage::INIT: + case IEJoinSourceStage::DONE: + break; + } + } + + if (!gsource.TryNextTask(task, task_local)) { + return false; + } + + auto &gsink = gsource.gsink; + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + switch (task->stage) { + case IEJoinSourceStage::SORT_L1: + case IEJoinSourceStage::MATERIALIZE_L1: + case IEJoinSourceStage::SORT_L2: + case IEJoinSourceStage::MATERIALIZE_L2: + break; + case IEJoinSourceStage::INNER: + // The join can hit any block on either side + left_block_index = 0; + left_base = 0; + + right_block_index = 0; + right_base = 0; + + joiner = make_uniq(gsource, task->l_range); + break; + case IEJoinSourceStage::OUTER: + if (task->thread_idx < gsource.left_outers) { + left_block_index = task->l_range.first; + left_base = left_table.BlockStart(left_block_index); + + left_matches = left_table.found_match.get() + left_base; + outer_idx = 0; + outer_count = left_table.BlockSize(left_block_index); + } else { + right_block_index = task->r_range.first; + right_base = right_table.BlockStart(right_block_index); + + right_matches = right_table.found_match.get() + right_base; + outer_idx = 0; + outer_count = right_table.BlockSize(right_block_index); + } + break; + case IEJoinSourceStage::INIT: + case IEJoinSourceStage::DONE: + break; + } + + return true; +} + +void IEJoinLocalSourceState::ExecuteSortL1Task(ExecutionContext &context, InterruptState &interrupt) { + auto &gsink = gsource.gsink; + auto &left_table = *gsink.tables[0]; + auto &right_table = *gsink.tables[1]; + + auto &op = gsource.op; + const auto &order1 = op.lhs_orders[0]; + const auto &order2 = op.lhs_orders[1]; + + auto &l1 = gsource.l1; + + // LHS has positive rids + ExpressionExecutor l_executor(context.client); + l_executor.AddExpression(*order1.expression); + // add const column true + auto left_const = make_uniq(Value::BOOLEAN(true)); + l_executor.AddExpression(*left_const); + l_executor.AddExpression(*order2.expression); + IEJoinUnion::AppendKey(context, interrupt, left_table, l_executor, *l1, 1, 1, task->l_range); + task->l_range.first = task->l_range.second; + + // RHS has negative rids + ExpressionExecutor r_executor(context.client); + r_executor.AddExpression(*op.rhs_orders[0].expression); + // add const column false + auto right_const = make_uniq(Value::BOOLEAN(false)); + r_executor.AddExpression(*right_const); + r_executor.AddExpression(*op.rhs_orders[1].expression); + IEJoinUnion::AppendKey(context, interrupt, right_table, r_executor, *l1, -1, -1, task->r_range); + task->r_range.first = task->r_range.second; +} + +void IEJoinLocalSourceState::ExecuteMaterializeL1Task(ExecutionContext &context, InterruptState &interrupt) { + IEJoinUnion::Sort(context, interrupt, *gsource.l1); +} + +void IEJoinLocalSourceState::ExecuteSortL2Task(ExecutionContext &context, InterruptState &interrupt) { + auto &l1 =
*gsource.l1; + auto &l2 = *gsource.l2; + + auto &op = gsource.op; + const auto &order2 = op.lhs_orders[1]; + auto ref = make_uniq(order2.expression->return_type, 0U); + + ExpressionExecutor executor(context.client); + executor.AddExpression(*ref); + IEJoinUnion::AppendKey(context, interrupt, l1, executor, l2, 1, 0, task->l_range); + + // Mark task as done + task->l_range.first = task->l_range.second; +} + +void IEJoinLocalSourceState::ExecuteMaterializeL2Task(ExecutionContext &context, InterruptState &interrupt) { + IEJoinUnion::Sort(context, interrupt, *gsource.l2); + task->l_range.first = task->l_range.second; +} + +void IEJoinLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &result, InterruptState &interrupt) { + switch (task->stage) { + case IEJoinSourceStage::INIT: + case IEJoinSourceStage::DONE: + break; + case IEJoinSourceStage::SORT_L1: + ExecuteSortL1Task(context, interrupt); + break; + case IEJoinSourceStage::MATERIALIZE_L1: + ExecuteMaterializeL1Task(context, interrupt); + break; + case IEJoinSourceStage::SORT_L2: + ExecuteSortL2Task(context, interrupt); + break; + case IEJoinSourceStage::MATERIALIZE_L2: + ExecuteMaterializeL2Task(context, interrupt); + break; + case IEJoinSourceStage::INNER: + ResolveComplexJoin(context, result); + break; + case IEJoinSourceStage::OUTER: + if (left_matches != nullptr) { + ExecuteLeftTask(context, result); + } else if (right_matches != nullptr) { + ExecuteRightTask(context, result); + } + break; + } +} + +void IEJoinLocalSourceState::ResolveComplexJoin(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto &conditions = op.conditions; + + auto &chunk = unprojected; + auto &left_table = *ie_sink.tables[0]; + const auto left_cols = op.children[0].get().GetTypes().size(); + auto &right_table = *ie_sink.tables[1]; - const auto left_cols = children[0].get().GetTypes().size(); - auto &chunk = state.unprojected; do { - SelectionVector lsel(STANDARD_VECTOR_SIZE); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); + auto result_count = joiner->JoinComplexBlocks(lsel, rsel); if (result_count == 0) { // exhausted this pair + joiner.reset(); return; } // found matches: extract them - chunk.Reset(); - SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0); - SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count, - left_cols); - chunk.SetCardinality(result_count); + left_table.Repin(*left_iterator); + right_table.Repin(*right_iterator); + + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, lsel, + result_count, *left_scan_state); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, + result_count, *right_scan_state); auto sel = FlatVector::IncrementalSelectionVector(); if (conditions.size() > 2) { // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison. + // use the left and right payloads + // so we can compute the values for comparison.
const auto tail_cols = conditions.size() - 2; - DataChunk right_chunk; - chunk.Split(right_chunk, left_cols); - state.left_executor.SetChunk(chunk); - state.right_executor.SetChunk(right_chunk); + left_executor.SetChunk(lpayload); + right_executor.SetChunk(rpayload); auto tail_count = result_count; - auto true_sel = &state.true_sel; + auto match_sel = &true_sel; for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) { - auto &left = state.left_keys.data[cmp_idx]; - state.left_executor.ExecuteExpression(cmp_idx, left); + auto &left = left_keys.data[cmp_idx]; + left_executor.ExecuteExpression(cmp_idx, left); - auto &right = state.right_keys.data[cmp_idx]; - state.right_executor.ExecuteExpression(cmp_idx, right); + auto &right = right_keys.data[cmp_idx]; + right_executor.ExecuteExpression(cmp_idx, right); if (tail_count < result_count) { left.Slice(*sel, tail_count); right.Slice(*sel, tail_count); } - tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); - sel = true_sel; + tail_count = + op.SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, match_sel); + sel = match_sel; } - chunk.Fuse(right_chunk); if (tail_count < result_count) { result_count = tail_count; - chunk.Slice(*sel, result_count); + lpayload.Slice(*sel, result_count); + rpayload.Slice(*sel, result_count); + } + } + + // Merge the payloads + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); } } + chunk.SetCardinality(result_count); // We need all of the data to compute other predicates, // but we only return what is in the projection map - ProjectResult(chunk, result); + op.ProjectResult(chunk, result); // found matches: mark the found matches if required + // NOTE: threadsan reports this as a data race because this can be set concurrently by separate + // threads Technically it is, but it does not matter, since the only value that can be written is + // "true" if (left_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true; + left_table.found_match[left_base + lsel[sel->get_index(i)]] = true; } } if (right_table.found_match) { for (idx_t i = 0; i < result_count; i++) { - right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true; + right_table.found_match[right_base + rsel[sel->get_index(i)]] = true; } } result.Verify(); } while (result.size() == 0); } -class IEJoinGlobalSourceState : public GlobalSourceState { -public: - explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) - : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), - right_outers(0), next_right(0) { - } +void IEJoinGlobalSourceState::Initialize() { + // INIT + stage_tasks.emplace_back(0); - void Initialize() { - auto guard = Lock(); - if (initialized) { - return; - } - - // Compute the starting row for reach block - // (In theory these are all the same size, but you never know...) 
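The ResolveComplexJoin loop above narrows the matched pairs one trailing condition at a time: the selection produced by one comparison becomes the input selection of the next. A simplified standalone version of that narrowing, using plain vectors instead of DuckDB's SelectionVector:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Sketch of the trailing-predicate loop: start with every candidate row selected, then apply
// each remaining predicate in turn; each predicate only sees rows that survived the previous ones.
static std::vector<std::size_t> FilterTail(std::size_t count,
                                           const std::vector<std::function<bool(std::size_t)>> &predicates) {
	std::vector<std::size_t> sel(count);
	for (std::size_t i = 0; i < count; ++i) {
		sel[i] = i;
	}
	for (const auto &pred : predicates) {
		std::vector<std::size_t> next;
		for (auto idx : sel) {
			if (pred(idx)) {
				next.push_back(idx);
			}
		}
		sel.swap(next);
		if (sel.empty()) {
			break;
		}
	}
	return sel;
}

int main() {
	// Two extra predicates over the row index, standing in for the cmp_idx expressions.
	std::vector<std::function<bool(std::size_t)>> predicates {
	    [](std::size_t i) { return i % 2 == 0; },
	    [](std::size_t i) { return i < 6; },
	};
	const auto sel = FilterTail(10, predicates);
	std::printf("%zu of 10 rows remain\n", sel.size()); // 3: indices 0, 2, 4
	return 0;
}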
- auto &left_table = *gsink.tables[0]; - const auto left_blocks = left_table.BlockCount(); - idx_t left_base = 0; + // SORT_L1 + stage_tasks.emplace_back(1); - for (size_t lhs = 0; lhs < left_blocks; ++lhs) { - left_bases.emplace_back(left_base); - left_base += left_table.BlockSize(lhs); - } + // MATERIALIZE_L1 + stage_tasks.emplace_back(1); - auto &right_table = *gsink.tables[1]; - const auto right_blocks = right_table.BlockCount(); - idx_t right_base = 0; - for (size_t rhs = 0; rhs < right_blocks; ++rhs) { - right_bases.emplace_back(right_base); - right_base += right_table.BlockSize(rhs); - } + // SORT_L2 + stage_tasks.emplace_back(1); - // Outer join block counts - if (left_table.found_match) { - left_outers = left_blocks; - } + // MATERIALIZE_L2 + stage_tasks.emplace_back(1); - if (right_table.found_match) { - right_outers = right_blocks; - } - - // Ready for action - initialized = true; + // INNER + idx_t inner_tasks = 0; + if (per_thread) { + inner_tasks = BinValue(join_blocks, per_thread); } + stage_tasks.emplace_back(inner_tasks); -public: - idx_t MaxThreads() override { - // We can't leverage any more threads than block pairs. - const auto &sink_state = (op.sink_state->Cast()); - return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); - } + // OUTER + stage_tasks.emplace_back(left_outers + right_outers); - void GetNextPair(ClientContext &client, IEJoinLocalSourceState &lstate) { - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; + // DONE + stage_tasks.emplace_back(0); - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; + // Accumulate task counts so we can find boundaries reliably + idx_t begin = 0; + for (const auto &stage_task : stage_tasks) { + stage_begin.emplace_back(begin); + begin += stage_task; + } - // Regular block - const auto i = next_pair++; - if (i < pair_count) { - const auto b1 = i / right_blocks; - const auto b2 = i % right_blocks; + total_tasks = stage_begin.back(); - lstate.left_block_index = b1; - lstate.left_base = left_bases[b1]; + // Set all the stage atomic counts to 0 + for (auto &stage_next : completed) { + stage_next = 0; + } - lstate.right_block_index = b2; - lstate.right_base = right_bases[b2]; + // Ready for action + stage = IEJoinSourceStage(1); +} - lstate.joiner = make_uniq(client, op, left_table, b1, right_table, b2); - return; +bool IEJoinGlobalSourceState::TryPrepareNextStage() { + // Inside lock + const auto stage_count = GetStageCount(stage); + const auto stage_next = GetStageNext(stage).load(); + switch (stage) { + case IEJoinSourceStage::INIT: + stage = IEJoinSourceStage::SORT_L1; + return true; + case IEJoinSourceStage::SORT_L1: + if (stage_next >= stage_count) { + stage = IEJoinSourceStage::MATERIALIZE_L1; + return true; } - - // Outer joins - if (!left_outers && !right_outers) { - return; + break; + case IEJoinSourceStage::MATERIALIZE_L1: + if (stage_next >= stage_count) { + stage = IEJoinSourceStage::SORT_L2; + return true; } - - // Spin wait for regular blocks to finish(!) 
- while (completed < pair_count) { - std::this_thread::yield(); + break; + case IEJoinSourceStage::SORT_L2: + if (stage_next >= stage_count) { + stage = IEJoinSourceStage::MATERIALIZE_L2; + return true; + } + break; + case IEJoinSourceStage::MATERIALIZE_L2: + if (stage_next >= stage_count) { + stage = IEJoinSourceStage::INNER; + return true; } + break; + case IEJoinSourceStage::INNER: + if (stage_next >= stage_count) { + if (GetStageCount(IEJoinSourceStage::OUTER)) { + stage = IEJoinSourceStage::OUTER; + } else { + stage = IEJoinSourceStage::DONE; + } + return true; + } + break; + case IEJoinSourceStage::OUTER: + if (stage_next >= stage_count) { + stage = IEJoinSourceStage::DONE; + return true; + } + break; + case IEJoinSourceStage::DONE: + return true; + } - // Left outer blocks - const auto l = next_left++; - if (l < left_outers) { - lstate.joiner = nullptr; - lstate.left_block_index = l; - lstate.left_base = left_bases[l]; + return false; +} - lstate.left_matches = left_table.found_match.get() + lstate.left_base; - lstate.outer_idx = 0; - lstate.outer_count = left_table.BlockSize(l); - return; - } else { - lstate.left_matches = nullptr; - } +idx_t IEJoinGlobalSourceState::MaxThreads() { + // We can't leverage any more threads than block pairs. + const auto &sink_state = (op.sink_state->Cast()); + return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); +} - // Right outer block - const auto r = next_right++; - if (r < right_outers) { - lstate.joiner = nullptr; - lstate.right_block_index = r; - lstate.right_base = right_bases[r]; +void IEJoinGlobalSourceState::FinishTask(TaskPtr task) { + // Inside the lock + if (!task) { + return; + } - lstate.right_matches = right_table.found_match.get() + lstate.right_base; - lstate.outer_idx = 0; - lstate.outer_count = right_table.BlockSize(r); - return; - } else { - lstate.right_matches = nullptr; - } + ++finished; +} + +bool IEJoinGlobalSourceState::TryNextTask(TaskPtr &task, Task &task_local) { + auto guard = Lock(); + FinishTask(task); + + if (!HasMoreTasks()) { + task = nullptr; + return false; } - void PairCompleted(ClientContext &client, IEJoinLocalSourceState &lstate) { - lstate.joiner.reset(); - ++completed; - GetNextPair(client, lstate); + if (TryPrepareNextStage()) { + UnblockTasks(guard); } - ProgressData GetProgress() const { - auto &left_table = *gsink.tables[0]; - auto &right_table = *gsink.tables[1]; + if (TryNextTask(task_local)) { + task = task_local; + ++started; + return true; + } - const auto left_blocks = left_table.BlockCount(); - const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; + task = nullptr; - const auto count = pair_count + left_outers + right_outers; + return false; +} - const auto l = MinValue(next_left.load(), left_outers.load()); - const auto r = MinValue(next_right.load(), right_outers.load()); - const auto returned = completed.load() + l + r; +bool IEJoinGlobalSourceState::TryNextTask(Task &task) { + if (next_task >= GetTaskCount()) { + return false; + } - ProgressData res; - if (count) { - res.done = double(returned); - res.total = double(count); - } else { - res.SetInvalid(); + // Search for where we are in the task list + for (idx_t stage = idx_t(IEJoinSourceStage::INIT); stage <= idx_t(IEJoinSourceStage::DONE); ++stage) { + if (next_task < stage_begin[stage]) { + task.stage = IEJoinSourceStage(stage - 1); + task.thread_idx = next_task - stage_begin[size_t(task.stage)]; + break; } - return res; } - const PhysicalIEJoin &op; - 
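// --- Editor's sketch (not part of the patch): the staged IEJoin source above converts the
// --- per-stage task counts into running offsets (stage_begin) and later maps a global task
// --- index back to a (stage, thread_idx) pair, which is what the linear search in
// --- TryNextTask does. A minimal, self-contained illustration of that bookkeeping; the
// --- stage names and task counts below are made up for the example.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
	const char *stage_names[] = {"INIT",           "SORT_L1", "MATERIALIZE_L1", "SORT_L2",
	                             "MATERIALIZE_L2", "INNER",   "OUTER",          "DONE"};
	// Hypothetical task counts per stage: the sort/materialize stages are single tasks,
	// INNER is split into 3 block ranges, and 2 outer blocks remain to be flushed.
	std::vector<std::size_t> stage_tasks = {0, 1, 1, 1, 1, 3, 2, 0};

	// Accumulate the counts so stage_begin[s] is the first global task index of stage s.
	std::vector<std::size_t> stage_begin;
	std::size_t begin = 0;
	for (const auto count : stage_tasks) {
		stage_begin.push_back(begin);
		begin += count;
	}
	const std::size_t total_tasks = begin;

	// Map every global task index back to its stage and its index within that stage.
	for (std::size_t next_task = 0; next_task < total_tasks; ++next_task) {
		std::size_t stage = stage_tasks.size() - 1;
		for (std::size_t s = 1; s < stage_begin.size(); ++s) {
			if (next_task < stage_begin[s]) {
				stage = s - 1;
				break;
			}
		}
		const std::size_t thread_idx = next_task - stage_begin[stage];
		std::printf("task %zu -> %s, thread_idx %zu\n", next_task, stage_names[stage], thread_idx);
	}
	return 0;
}
// --- End of editor's sketch.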
IEJoinGlobalState &gsink; + if (task.stage != stage) { + return false; + } - bool initialized; + switch (stage) { + case IEJoinSourceStage::SORT_L1: + task.l_range = {0, left_blocks}; + task.r_range = {0, right_blocks}; + break; + case IEJoinSourceStage::MATERIALIZE_L1: + task.l_range = {0, 1}; + task.r_range = {0, 0}; + break; + case IEJoinSourceStage::SORT_L2: + task.l_range = {0, l1->BlockCount()}; + break; + case IEJoinSourceStage::MATERIALIZE_L2: + task.l_range = {0, 1}; + task.r_range = {0, 0}; + break; + case IEJoinSourceStage::INNER: { + task.l_range.first = task.thread_idx * per_thread; + task.l_range.second = MinValue(task.l_range.first + per_thread, join_blocks); + break; + } + case IEJoinSourceStage::OUTER: + if (task.thread_idx < left_outers) { + // Left outer blocks + const auto left_task = task.thread_idx; + task.l_range = {left_task, left_task + 1}; + task.r_range = {0, 0}; + } else { + // Right outer blocks + const auto right_task = task.thread_idx - left_outers; + task.l_range = {0, 0}; + task.r_range = {right_task, right_task + 1}; + } + break; + case IEJoinSourceStage::INIT: + case IEJoinSourceStage::DONE: + break; + } - // Join queue state - atomic next_pair; - atomic completed; + ++next_task; - // Block base row number - vector left_bases; - vector right_bases; + return true; +} - // Outer joins - atomic left_outers; - atomic next_left; +ProgressData IEJoinGlobalSourceState::GetProgress() const { + const auto count = GetTaskCount(); - atomic right_outers; - atomic next_right; -}; + const auto returned = finished.load(); -unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const { + ProgressData res; + if (count) { + res.done = double(returned); + res.total = double(count); + } else { + res.SetInvalid(); + } + return res; +} +unique_ptr PhysicalIEJoin::GetGlobalSourceState(ClientContext &client) const { auto &gsink = sink_state->Cast(); - return make_uniq(*this, gsink); + return make_uniq(*this, client, gsink); } unique_ptr PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(context.client, *this); + auto &gsource = gstate.Cast(); + return make_uniq(context.client, gsource); } ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { @@ -995,82 +1472,97 @@ ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceSta return gsource.GetProgress(); } -SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - auto &ie_sink = sink_state->Cast(); - auto &ie_gstate = input.global_state.Cast(); - auto &ie_lstate = input.local_state.Cast(); - - ie_gstate.Initialize(); +SourceResultType PhysicalIEJoin::GetDataInternal(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); + auto &lsource = input.local_state.Cast(); - if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_lstate); + // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done + // Therefore, we loop until we've produced tuples, or until the operator is actually done + while (gsource.stage != IEJoinSourceStage::DONE && result.size() == 0) { + if (!lsource.TaskFinished() || lsource.TryAssignTask()) { + lsource.ExecuteTask(context, result, input.interrupt_state); + } else { + auto guard = gsource.Lock(); + if 
(gsource.TryPrepareNextStage() || gsource.stage == IEJoinSourceStage::DONE) { + gsource.UnblockTasks(guard); + } else { + return gsource.BlockSource(guard, input.interrupt_state); + } + } } + return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} - // Process INNER results - while (ie_lstate.joiner) { - ResolveComplexJoin(context, result, ie_lstate); +void IEJoinLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); - if (result.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } + const auto left_cols = op.children[0].get().GetTypes().size(); + auto &chunk = unprojected; - ie_gstate.PairCompleted(context.client, ie_lstate); + const idx_t count = SelectOuterRows(left_matches); + if (!count) { + left_matches = nullptr; + return; } - // Process LEFT OUTER results - const auto left_cols = children[0].get().GetTypes().size(); - while (ie_lstate.left_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, - count); + auto &left_table = *ie_sink.tables[0]; - // Fill in NULLs to the right - for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { + left_table.Repin(*left_iterator); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, true_sel, count, + *left_scan_state); + + // Fill in NULLs to the right + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::SetNull(chunk.data[col_idx], true); } + } - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); +} + +void IEJoinLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + const auto left_cols = op.children[0].get().GetTypes().size(); + + auto &chunk = unprojected; - return result.size() == 0 ? 
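// --- Editor's sketch (not part of the patch): ExecuteLeftTask above selects the left-table
// --- rows whose found_match flag is still false and emits them with the right-hand columns
// --- set to constant NULL. The same selection + NULL padding on plain containers (data and
// --- names are invented; DuckDB uses SelectionVector and ConstantVector for this):
#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

int main() {
	// Left payload and the per-row match flags produced by the INNER phase.
	const std::vector<int> left_payload = {10, 20, 30, 40};
	const std::vector<bool> found_match = {true, false, true, false};

	// SelectOuterRows equivalent: collect the indices of rows that never matched.
	std::vector<std::size_t> sel;
	for (std::size_t i = 0; i < found_match.size(); ++i) {
		if (!found_match[i]) {
			sel.push_back(i);
		}
	}

	// Emit the unmatched left rows; the right side of the result is padded with NULLs.
	for (const auto idx : sel) {
		const std::optional<int> right_value; // stays empty, i.e. NULL
		std::printf("left=%d right=%s\n", left_payload[idx], right_value ? "<value>" : "NULL");
	}
	return 0;
}
// --- End of editor's sketch.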
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + const idx_t count = SelectOuterRows(right_matches); + if (!count) { + right_matches = nullptr; + return; } - // Process RIGHT OUTER results - while (ie_lstate.right_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } + auto &right_table = *ie_sink.tables[1]; + auto &rsel = true_sel; - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, - count, left_cols); + right_table.Repin(*right_iterator); + op.SliceSortedPayload(rpayload, right_table, *right_iterator, right_chunk_state, right_block_index, rsel, count, + *right_scan_state); - // Fill in NULLs to the left - for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { + // Fill in NULLs to the left + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::SetNull(chunk.data[col_idx], true); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - break; } - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + op.ProjectResult(chunk, result); + result.SetCardinality(count); + result.Verify(); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp index d96cda05d..511571854 100644 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -1,7 +1,5 @@ #include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" #include "duckdb/parallel/thread_context.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/nested_loop_join.hpp" #include "duckdb/main/client_context.hpp" @@ -9,20 +7,22 @@ namespace duckdb { -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info_p) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, - estimated_cardinality) { - + estimated_cardinality), + predicate(std::move(op.predicate)) { filter_pushdown = std::move(pushdown_info_p); children.push_back(left); children.push_back(right); } -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, idx_t estimated_cardinality) : 
PhysicalNestedLoopJoin(physical_plan, op, left, right, std::move(cond), join_type, estimated_cardinality, nullptr) { @@ -273,7 +273,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, const vector &conditions) : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)) { + left_outer(IsLeftOuterJoin(op.join_type)), pred_executor(context) { vector condition_types; for (auto &cond : conditions) { lhs_executor.AddExpression(*cond.left); @@ -284,6 +284,11 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { right_condition.Initialize(allocator, condition_types); right_payload.Initialize(allocator, op.children[1].get().GetTypes()); left_outer.Initialize(STANDARD_VECTOR_SIZE); + + if (op.predicate) { + pred_executor.AddExpression(*op.predicate); + pred_matches.Initialize(); + } } bool fetch_next_left; @@ -302,6 +307,10 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { OuterJoinMarker left_outer; + //! Predicate + ExpressionExecutor pred_executor; + SelectionVector pred_matches; + public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); @@ -438,11 +447,20 @@ OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext & if (match_count > 0) { // we have matching tuples! // construct the result - state.left_outer.SetMatches(lvector, match_count); - gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); - chunk.Slice(input, lvector, match_count); chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); + + // If we have a predicate, apply it to the result + if (predicate) { + auto &sel = state.pred_matches; + match_count = state.pred_executor.SelectExpression(chunk, sel); + chunk.Slice(sel, match_count); + lvector.SliceInPlace(sel, match_count); + rvector.SliceInPlace(sel, match_count); + } + + state.left_outer.SetMatches(lvector, match_count); + gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); } // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration @@ -494,8 +512,8 @@ unique_ptr PhysicalNestedLoopJoin::GetLocalSourceState(Executi return make_uniq(*this, gstate.Cast()); } -SourceResultType PhysicalNestedLoopJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalNestedLoopJoin::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { D_ASSERT(PropagatesBuildSide(join_type)); // check if we need to scan any unmatched tuples from the RHS for the full/right outer join auto &sink = sink_state->Cast(); diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp index 1bd48ab62..b70e673ac 100644 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -1,11 +1,8 @@ #include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include 
"duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" #include "duckdb/main/client_context.hpp" @@ -21,7 +18,6 @@ PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(PhysicalPlan &physical_pl unique_ptr pushdown_info_p) : PhysicalRangeJoin(physical_plan, op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, left, right, std::move(cond), join_type, estimated_cardinality, std::move(pushdown_info_p)) { - for (auto &join_cond : conditions) { D_ASSERT(join_cond.left->return_type == join_cond.right->return_type); join_key_types.push_back(join_cond.left->return_type); @@ -65,15 +61,14 @@ class MergeJoinGlobalState : public GlobalSinkState { using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + MergeJoinGlobalState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) { + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout, op); + table = make_uniq(client, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); - global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + global_filter_state = op.filter_pushdown->GetGlobalState(client, op); } } @@ -81,8 +76,9 @@ class MergeJoinGlobalState : public GlobalSinkState { return table->count; } - void Sink(DataChunk &input, MergeJoinLocalState &lstate); + void Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate); + //! The sorted table unique_ptr table; //! Should we not bother pushing down filters? bool skip_filter_pushdown = false; @@ -92,16 +88,19 @@ class MergeJoinGlobalState : public GlobalSinkState { class MergeJoinLocalState : public LocalSinkState { public: - explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, MergeJoinGlobalState &gstate, - const idx_t child) - : table(context, op, child) { + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + + MergeJoinLocalState(ExecutionContext &context, MergeJoinGlobalState &gstate, const idx_t child) + : table(context, *gstate.table, child) { + auto &op = gstate.table->op; if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); } } //! The local sort state - PhysicalRangeJoin::LocalSortedTable table; + LocalSortedTable table; //! 
Local state for accumulating filter statistics unique_ptr local_filter_state; }; @@ -113,20 +112,12 @@ unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(Clien unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { // We only sink the RHS auto &gstate = sink_state->Cast(); - return make_uniq(context.client, *this, gstate, 1U); + return make_uniq(context, gstate, 1U); } -void MergeJoinGlobalState::Sink(DataChunk &input, MergeJoinLocalState &lstate) { - auto &global_sort_state = table->global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void MergeJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, @@ -134,7 +125,7 @@ SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataC auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -147,7 +138,7 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.table->Combine(lstate.table); + gstate.table->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -162,25 +153,28 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } - auto &global_sort_state = gstate.table->global_sort_state; + + gstate.table->Finalize(client, input.interrupt_state); if (PropagatesBuildSide(join_type)) { // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple gstate.table->IntializeMatches(); } - if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + + if (gstate.table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! 
+ gstate.table->MaterializeEmpty(client); return SinkFinalizeType::NO_OUTPUT_POSSIBLE; } // Sort the current input child - gstate.table->Finalize(pipeline, event); + gstate.table->Materialize(pipeline, event); return SinkFinalizeType::READY; } @@ -191,46 +185,50 @@ SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event class PiecewiseMergeJoinState : public CachingOperatorState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), - left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), - right_position(0), right_chunk_index(0), rhs_executor(context) { - vector condition_types; - for (auto &order : op.lhs_orders) { - condition_types.push_back(order.expression->return_type); - } + PiecewiseMergeJoinState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) + : client(client), allocator(Allocator::Get(client)), op(op), left_outer(IsLeftOuterJoin(op.join_type)), + left_position(0), first_fetch(true), finished(true), right_position(0), right_chunk_index(0), + rhs_executor(client) { left_outer.Initialize(STANDARD_VECTOR_SIZE); - lhs_layout.Initialize(op.children[0].get().GetTypes()); - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + // Sort on the first column lhs_order.emplace_back(op.lhs_orders[0].Copy()); // Set up shared data for multiple predicates sel.Initialize(STANDARD_VECTOR_SIZE); - condition_types.clear(); + vector condition_types; for (auto &order : op.rhs_orders) { rhs_executor.AddExpression(*order.expression); condition_types.push_back(order.expression->return_type); } - rhs_keys.Initialize(allocator, condition_types); + rhs_keys.Initialize(client, condition_types); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); + + auto &gsink = op.sink_state->Cast(); + auto &rhs_table = *gsink.table; + rhs_iterator = rhs_table.CreateIteratorState(); + rhs_table.InitializePayloadState(rhs_chunk_state); + rhs_scan_state = rhs_table.CreateScanState(client); + + // Since we have now materialized the payload, the keys will not have payloads? 
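// --- Editor's sketch (not part of the patch): ResolveJoinKeys below sorts every incoming
// --- LHS chunk on its first join key before it is merged against the pre-sorted RHS table,
// --- and then recomputes the key columns from the now-sorted payload. The same idea on
// --- plain vectors, using an index permutation instead of DuckDB's Sort machinery (all
// --- names and data here are invented):
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
	// One "chunk" of LHS payload rows and the join key derived from them.
	const std::vector<int> payload = {42, 7, 19, 3};
	std::vector<int> keys(payload.size());
	for (std::size_t i = 0; i < payload.size(); ++i) {
		keys[i] = payload[i] % 10; // stand-in for evaluating the join-key expression
	}

	// Sort an index permutation by key, then read payload and keys in sorted order.
	std::vector<std::size_t> perm(payload.size());
	std::iota(perm.begin(), perm.end(), std::size_t(0));
	std::sort(perm.begin(), perm.end(), [&](std::size_t a, std::size_t b) { return keys[a] < keys[b]; });

	for (const auto idx : perm) {
		std::printf("key=%d payload=%d\n", keys[idx], payload[idx]);
	}
	return 0;
}
// --- End of editor's sketch.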
+ sort_key_type = rhs_table.GetSortKeyType(); } - ClientContext &context; + ClientContext &client; Allocator &allocator; const PhysicalPiecewiseMergeJoin &op; - BufferManager &buffer_manager; - bool force_external; // Block sorting DataChunk lhs_payload; OuterJoinMarker left_outer; vector lhs_order; - RowLayout lhs_layout; + unique_ptr lhs_global_table; unique_ptr lhs_local_table; - unique_ptr lhs_global_state; - unique_ptr scanner; + SortKeyType sort_key_type; + TupleDataScanState lhs_scan; // Simple scans idx_t left_position; @@ -238,178 +236,127 @@ class PiecewiseMergeJoinState : public CachingOperatorState { // Complex scans bool first_fetch; bool finished; + unique_ptr lhs_iterator; + unique_ptr rhs_iterator; idx_t right_position; idx_t right_chunk_index; idx_t right_base; idx_t prev_left_index; + TupleDataChunkState rhs_chunk_state; + unique_ptr rhs_scan_state; // Secondary predicate shared data SelectionVector sel; DataChunk rhs_keys; DataChunk rhs_input; ExpressionExecutor rhs_executor; - vector payload_heap_handles; public: - void ResolveJoinKeys(DataChunk &input) { + void ResolveJoinKeys(ExecutionContext &context, DataChunk &input) { // sort by join key - lhs_global_state = make_uniq(context, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0U); - lhs_local_table->Sink(input, *lhs_global_state); - - // Set external (can be forced with the PRAGMA) - lhs_global_state->external = force_external; - lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); - lhs_global_state->PrepareMergePhase(); - while (lhs_global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*lhs_global_state, buffer_manager); - merge_sorter.PerformInMergeRound(); - lhs_global_state->CompleteMergeRound(); - } - - // Scan the sorted payload - D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); - - scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); - lhs_payload.Reset(); - scanner->Scan(lhs_payload); + const auto &lhs_types = lhs_payload.GetTypes(); + lhs_global_table = make_uniq(context.client, lhs_order, lhs_types, op); + lhs_local_table = make_uniq(context, *lhs_global_table, 0U); + lhs_local_table->Sink(context, input); + lhs_global_table->Combine(context, *lhs_local_table); + + InterruptState interrupt; + lhs_global_table->Finalize(context.client, interrupt); + lhs_global_table->Materialize(context, interrupt); + + // Scan the sorted payload (minus the primary sort column) + auto &lhs_table = *lhs_global_table; + auto &lhs_payload_data = *lhs_table.sorted->payload_data; + lhs_payload_data.InitializeScan(lhs_scan); + lhs_payload_data.Scan(lhs_scan, lhs_payload); // Recompute the sorted keys from the sorted input - lhs_local_table->keys.Reset(); - lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); - } + auto &lhs_keys = lhs_local_table->keys; + lhs_keys.Reset(); + lhs_local_table->executor.Execute(lhs_payload, lhs_keys); - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - if (lhs_local_table) { - context.thread.profiler.Flush(op); - } + lhs_iterator = lhs_table.CreateIteratorState(); } }; unique_ptr PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { - bool force_external = ClientConfig::GetConfig(context.client).force_external; - return make_uniq(context.client, *this, force_external); + return make_uniq(context.client, *this); } -static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { - return 
MinValue(base + count, MaxValue(base, not_null)) - base; +static inline idx_t SortedChunkNotNull(const idx_t chunk_idx, const idx_t count, const idx_t has_null) { + const auto chunk_begin = chunk_idx * STANDARD_VECTOR_SIZE; + const auto chunk_end = MinValue(chunk_begin + STANDARD_VECTOR_SIZE, count); + const auto not_null = count - has_null; + return MinValue(chunk_end, MaxValue(chunk_begin, not_null)) - chunk_begin; } -static int MergeJoinComparisonValue(ExpressionType comparison) { +static bool MergeJoinStrictComparison(ExpressionType comparison) { switch (comparison) { case ExpressionType::COMPARE_LESSTHAN: case ExpressionType::COMPARE_GREATERTHAN: - return -1; + return true; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; + return false; default: throw InternalException("Unimplemented comparison type for merge join!"); } } -struct BlockMergeInfo { - GlobalSortState &state; - //! The block being scanned - const idx_t block_idx; - //! The number of not-NULL values in the block (they are at the end) - const idx_t not_null; - //! The current offset in the block - idx_t &entry_idx; - SelectionVector result; - - BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) - : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { - } -}; - -static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { - scan.SetIndices(block_idx, 0); - scan.PinRadix(block_idx); - - auto &sd = *scan.sb->blob_sorting_data; - if (block_idx < sd.data_blocks.size()) { - scan.PinData(sd); +// Compare using +bool MergeJoinBefore(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); } + return less_than; } -static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { - scan.entry_idx = entry_idx; - return scan.RadixPtr(); -} +template +static idx_t TemplatedMergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, + bool *found_match, const bool strict) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; -static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, - const ExpressionType comparison) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - auto &lsort = *lstate.lhs_global_state; - auto &rsort = rstate.table->global_sort_state; - D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); - const auto all_constant = lsort.sort_layout.all_constant; - D_ASSERT(lsort.external == rsort.external); - const auto external = lsort.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(lsort.sorted_blocks.size() == 1); - SBScanState lread(lsort.buffer_manager, lsort); - lread.sb = lsort.sorted_blocks[0].get(); - - const idx_t l_block_idx = 0; - idx_t l_entry_idx = 0; - const auto lhs_not_null = lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; - MergeJoinPinSortingBlock(lread, l_block_idx); - auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); - - D_ASSERT(rsort.sorted_blocks.size() == 1); - SBScanState rread(rsort.buffer_manager, rsort); - rread.sb = rsort.sorted_blocks[0].get(); + // We only need the keys because we are extracting the row numbers + auto &lhs_table = *lstate.lhs_global_table; + D_ASSERT(SORT_KEY_TYPE == 
lhs_table.GetSortKeyType()); + auto &lhs_iterator = *lstate.lhs_iterator; + const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - const auto cmp_size = lsort.sort_layout.comparison_size; - const auto entry_size = lsort.sort_layout.entry_size; + auto &rhs_table = *gstate.table; + auto &rhs_iterator = *lstate.rhs_iterator; + const auto rhs_not_null = rhs_table.count - rhs_table.has_null; - idx_t right_base = 0; - for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { - // we only care about the BIGGEST value in each of the RHS data blocks + idx_t l_entry_idx = 0; + BLOCK_ITERATOR lhs_itr(lhs_iterator); + BLOCK_ITERATOR rhs_itr(rhs_iterator); + for (idx_t r_idx = 0; r_idx < rhs_not_null; r_idx += STANDARD_VECTOR_SIZE) { + // Repin the RHS to release memory + // This is safe because we only return the LHS values + // Note we only do this for the RHS because the LHS is only one chunk. + rhs_table.Repin(rhs_iterator); + + // we only care about the BIGGEST value in the RHS // because we want to figure out if the LHS values are less than [or equal] to ANY value - // get the biggest value from the RHS chunk - MergeJoinPinSortingBlock(rread, r_block_idx); - - auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; - const auto r_not_null = - SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); - if (r_not_null == 0) { - break; - } - const auto r_entry_idx = r_not_null - 1; - right_base += rblock.count; - - auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); + const auto r_entry_idx = MinValue(r_idx + STANDARD_VECTOR_SIZE, rhs_not_null) - 1; // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS // value while (true) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l_entry_idx; - rread.entry_idx = r_entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); - } - - if (comp_res <= cmp) { + // Note that both subscripts here are table indices, not chunk indices. + if (MergeJoinBefore(lhs_itr[l_entry_idx], rhs_itr[r_entry_idx], strict)) { // found a match for lpos, set it in the found_match vector found_match[l_entry_idx] = true; l_entry_idx++; - l_ptr += entry_size; if (l_entry_idx >= lhs_not_null) { // early out: we exhausted the entire LHS and they all match return 0; } } else { // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not - // match move to the next RHS chunk + // match. 
Move to the next RHS chunk break; } } @@ -417,13 +364,42 @@ static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlo return 0; } +static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, bool *match, + const ExpressionType comparison) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (lstate.sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(lstate.sort_key_type)); + } +} + void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - state.ResolveJoinKeys(input); - auto &lhs_table = *state.lhs_local_table; + state.ResolveJoinKeys(context, input); + auto &lhs_table = *state.lhs_global_table; + auto &lhs_keys = state.lhs_local_table->keys; // perform the actual join bool found_match[STANDARD_VECTOR_SIZE]; @@ -439,8 +415,8 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da case JoinType::MARK: { // The only part of the join keys that is actually used is the validity mask. // Since the payload is sorted, we can just set the tail end of the validity masks to invalid. 
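// --- Editor's sketch (not part of the patch): MergeJoinBefore above implements both the
// --- strict (<, >) and non-strict (<=, >=) merge comparisons using only operator< on the
// --- sort key; for the non-strict case it falls back to !(rhs < lhs). A small check of that
// --- equivalence with plain ints standing in for the sort key type:
#include <cassert>

template <class T>
static bool MergeJoinBeforeSketch(const T &lhs, const T &rhs, const bool strict) {
	const bool less_than = lhs < rhs;
	if (!less_than && !strict) {
		return !(rhs < lhs); // equal under a strict weak ordering, so "before or equal"
	}
	return less_than;
}

int main() {
	// strict == true corresponds to COMPARE_LESSTHAN / COMPARE_GREATERTHAN
	assert(MergeJoinBeforeSketch(1, 2, true));   // 1 <  2
	assert(!MergeJoinBeforeSketch(2, 2, true));  // 2 <  2 is false
	// strict == false corresponds to COMPARE_LESSTHANOREQUALTO / COMPARE_GREATERTHANOREQUALTO
	assert(MergeJoinBeforeSketch(2, 2, false));  // 2 <= 2
	assert(!MergeJoinBeforeSketch(3, 2, false)); // 3 <= 2 is false
	return 0;
}
// --- End of editor's sketch.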
- for (auto &key : lhs_table.keys.data) { - key.Flatten(lhs_table.keys.size()); + for (auto &key : lhs_keys.data) { + key.Flatten(lhs_keys.size()); auto &mask = FlatVector::Validity(key); if (mask.AllValid()) { continue; @@ -451,7 +427,7 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } // So we make a set of keys that have the validity mask set for the - PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); + PhysicalJoin::ConstructMarkJoinResult(lhs_keys, payload, chunk, found_match, gstate.table->has_null); break; } case JoinType::SEMI: @@ -465,40 +441,40 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } -static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, - idx_t &prev_left_index) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); - const auto all_constant = r.state.sort_layout.all_constant; - D_ASSERT(l.state.external == r.state.external); - const auto external = l.state.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(l.state.sorted_blocks.size() == 1); - SBScanState lread(l.state.buffer_manager, l.state); - lread.sb = l.state.sorted_blocks[0].get(); - D_ASSERT(lread.sb->radix_sorting_data.size() == 1); - MergeJoinPinSortingBlock(lread, l.block_idx); - auto l_start = MergeJoinRadixPtr(lread, 0); - auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); - - D_ASSERT(r.state.sorted_blocks.size() == 1); - SBScanState rread(r.state.buffer_manager, r.state); - rread.sb = r.state.sorted_blocks[0].get(); +struct ChunkMergeInfo { + //! The iteration state + ExternalBlockIteratorState &state; + //! The block being scanned + const idx_t block_idx; + //! The number of not-NULL values in the chunk (they are at the end) + const idx_t not_null; + //! The current offset in the chunk + idx_t &entry_idx; + //! The offsets that match + SelectionVector result; - if (r.entry_idx >= r.not_null) { - return 0; + ChunkMergeInfo(ExternalBlockIteratorState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) + : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { } - MergeJoinPinSortingBlock(rread, r.block_idx); - auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); + idx_t GetIndex() const { + return state.GetIndex(block_idx, entry_idx); + } +}; - const auto cmp_size = l.state.sort_layout.comparison_size; - const auto entry_size = l.state.sort_layout.entry_size; +template +static idx_t TemplatedMergeJoinComplexBlocks(ChunkMergeInfo &l, ChunkMergeInfo &r, const bool strict, + idx_t &prev_left_index) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + + if (r.entry_idx >= r.not_null) { + return 0; + } idx_t result_count = 0; + BLOCK_ITERATOR l_ptr(l.state); + BLOCK_ITERATOR r_ptr(r.state); while (true) { if (l.entry_idx < prev_left_index) { // left side smaller: found match @@ -507,7 +483,7 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! 
break; @@ -515,22 +491,14 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const continue; } if (l.entry_idx < l.not_null) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l.entry_idx; - rread.entry_idx = r.entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); - } - if (comp_res <= cmp) { + if (MergeJoinBefore(l_ptr[l.GetIndex()], r_ptr[r.GetIndex()], strict)) { // left side smaller: found match l.result.set_index(result_count, sel_t(l.entry_idx)); r.result.set_index(result_count, sel_t(r.entry_idx)); result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! break; @@ -546,27 +514,53 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const if (r.entry_idx >= r.not_null) { break; } - r_ptr += entry_size; + ++r_ptr; - l_ptr = l_start; l.entry_idx = 0; } return result_count; } +static idx_t MergeJoinComplexBlocks(const SortKeyType &sort_key_type, ChunkMergeInfo &l, ChunkMergeInfo &r, + const ExpressionType comparison, idx_t &prev_left_index) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); + } +} + OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; const auto left_cols = input.ColumnCount(); const auto tail_cols = conditions.size() - 1; - state.payload_heap_handles.clear(); do { if (state.first_fetch) { - state.ResolveJoinKeys(input); + state.ResolveJoinKeys(context, input); + state.lhs_payload.Verify(); state.right_chunk_index = 0; state.right_base = 0; @@ -588,36 +582,44 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte return OperatorResultType::NEED_MORE_INPUT; } - auto &lhs_table = *state.lhs_local_table; + auto &lhs_table = *state.lhs_global_table; const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); + ChunkMergeInfo left_info(*state.lhs_iterator, 0, state.left_position, lhs_not_null); + + auto 
&rhs_table = *gstate.table; + auto &rhs_iterator = *state.rhs_iterator; + const auto rhs_not_null = SortedChunkNotNull(state.right_chunk_index, rhs_table.count, rhs_table.has_null); + ChunkMergeInfo right_info(rhs_iterator, state.right_chunk_index, state.right_position, rhs_not_null); - const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; - const auto rhs_not_null = - SortedBlockNotNull(state.right_base, rblock.count, gstate.table->count - gstate.table->has_null); - BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, - rhs_not_null); + // Repin so we don't hang on to data after we have scanned it + // Note we only do this for the RHS because the LHS is only one chunk. + rhs_table.Repin(rhs_iterator); - idx_t result_count = - MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); + idx_t result_count = MergeJoinComplexBlocks(state.sort_key_type, left_info, right_info, + conditions[0].comparison, state.prev_left_index); if (result_count == 0) { // exhausted this chunk on the right side // move to the next right chunk state.left_position = 0; state.right_position = 0; - state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; + state.right_base += STANDARD_VECTOR_SIZE; state.right_chunk_index++; - if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { + if (state.right_chunk_index >= rhs_table.BlockCount()) { state.finished = true; } } else { // found matches: extract them + SliceSortedPayload(state.rhs_input, rhs_table, rhs_iterator, state.rhs_chunk_state, right_info.block_idx, + right_info.result, result_count, *state.rhs_scan_state); + chunk.Reset(); - for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { - chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Slice(state.lhs_payload.data[col_idx], left_info.result, result_count); + } else { + chunk.data[col_idx].Reference(state.rhs_input.data[col_idx - left_cols]); + } } - state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, - right_info.result, result_count, left_cols)); chunk.SetCardinality(result_count); auto sel = FlatVector::IncrementalSelectionVector(); @@ -625,13 +627,12 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte // If there are more expressions to compute, // split the result chunk into the left and right halves // so we can compute the values for comparison. 
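// --- Editor's sketch (not part of the patch): the merge above only resolves the first join
// --- condition (the IEJoin handles its first two); any remaining conditions are applied by
// --- repeatedly shrinking the current selection, which is what the SelectJoinTail loop in
// --- the surrounding code does. A self-contained sketch of that successive refinement; the
// --- rows and predicates are invented for the example.
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
	struct Row {
		int l_a, l_b, r_a, r_b;
	};
	// Candidate matches produced by the first (sorted) condition l_a < r_a.
	const std::vector<Row> candidates = {{1, 5, 2, 4}, {1, 9, 2, 4}, {2, 3, 3, 7}};

	// Start with every candidate selected, then refine once per remaining condition.
	std::vector<std::size_t> sel(candidates.size());
	for (std::size_t i = 0; i < sel.size(); ++i) {
		sel[i] = i;
	}
	const std::vector<std::function<bool(const Row &)>> tail_conditions = {
	    [](const Row &row) { return row.l_b < row.r_b; }, // second condition
	};
	for (const auto &condition : tail_conditions) {
		std::vector<std::size_t> next;
		for (const auto idx : sel) {
			if (condition(candidates[idx])) {
				next.push_back(idx);
			}
		}
		sel.swap(next); // later conditions only look at the rows that survived so far
	}
	for (const auto idx : sel) {
		std::printf("match: candidate %zu\n", idx);
	}
	return 0;
}
// --- End of editor's sketch.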
- chunk.Split(state.rhs_input, left_cols); state.rhs_executor.SetChunk(state.rhs_input); state.rhs_keys.Reset(); auto tail_count = result_count; for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { - Vector left(lhs_table.keys.data[cmp_idx]); + Vector left(state.lhs_local_table->keys.data[cmp_idx]); left.Slice(left_info.result, result_count); auto &right = state.rhs_keys.data[cmp_idx]; @@ -645,7 +646,6 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); sel = &state.sel; } - chunk.Fuse(state.rhs_input); if (tail_count < result_count) { result_count = tail_count; @@ -713,54 +713,78 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -class PiecewiseJoinScanState : public GlobalSourceState { +class PiecewiseJoinGlobalScanState : public GlobalSourceState { +public: + explicit PiecewiseJoinGlobalScanState(TupleDataCollection &payload) : payload(payload), right_outer_position(0) { + payload.InitializeScan(parallel_scan); + } + + idx_t Scan(TupleDataLocalScanState &local_scan, DataChunk &chunk) { + lock_guard guard(lock); + const auto result = right_outer_position; + payload.Scan(parallel_scan, local_scan, chunk); + right_outer_position += chunk.size(); + return result; + } + + TupleDataCollection &payload; + public: - explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { + idx_t MaxThreads() override { + return payload.ChunkCount(); } +private: mutex lock; - const PhysicalPiecewiseMergeJoin &op; - unique_ptr scanner; + TupleDataParallelScanState parallel_scan; idx_t right_outer_position; +}; +class PiecewiseJoinLocalScanState : public LocalSourceState { public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); + explicit PiecewiseJoinLocalScanState(PiecewiseJoinGlobalScanState &gstate) : rsel(STANDARD_VECTOR_SIZE) { + gstate.payload.InitializeScan(scanner); + gstate.payload.InitializeChunk(rhs_chunk); } + + TupleDataLocalScanState scanner; + DataChunk rhs_chunk; + SelectionVector rsel; }; unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); + auto &gsink = sink_state->Cast(); + return make_uniq(*gsink.table->sorted->payload_data); } -SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { +unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(gstate.Cast()); +} + +SourceResultType PhysicalPiecewiseMergeJoin::GetDataInternal(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &source) const { D_ASSERT(PropagatesBuildSide(join_type)); // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &state = input.global_state.Cast(); - - lock_guard l(state.lock); - if (!state.scanner) { - // Initialize scanner (if not yet initialized) - auto &sort_state = sink.table->global_sort_state; - if (sort_state.sorted_blocks.empty()) { - return SourceResultType::FINISHED; - } - state.scanner = 
make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); + auto &gsink = sink_state->Cast(); + auto &gsource = source.global_state.Cast(); + + // RHS was empty, so nothing to do? + if (!gsink.table->count) { + return SourceResultType::FINISHED; } // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we // still need to output - const auto found_match = sink.table->found_match.get(); + const auto found_match = gsink.table->found_match.get(); - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); + auto &lsource = source.local_state.Cast(); + auto &rhs_chunk = lsource.rhs_chunk; + auto &rsel = lsource.rsel; for (;;) { // Read the next sorted chunk - state.scanner->Scan(rhs_chunk); + rhs_chunk.Reset(); + const auto rhs_pos = gsource.Scan(lsource.scanner, rhs_chunk); const auto count = rhs_chunk.size(); if (count == 0) { @@ -770,11 +794,10 @@ SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, idx_t result_count = 0; // figure out which tuples didn't find a match in the RHS for (idx_t i = 0; i < count; i++) { - if (!found_match[state.right_outer_position + i]) { + if (!found_match[rhs_pos + i]) { rsel.set_index(result_count++, i); } } - state.right_outer_position += count; if (result_count > 0) { // if there were any tuples that didn't find a match, output them diff --git a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp index ab4a15091..35f759599 100644 --- a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp @@ -171,8 +171,8 @@ void PositionalJoinGlobalState::GetData(DataChunk &output) { output.SetCardinality(count); } -SourceResultType PhysicalPositionalJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { +SourceResultType PhysicalPositionalJoin::GetDataInternal(ExecutionContext &context, DataChunk &result, + OperatorSourceInput &input) const { auto &sink = sink_state->Cast(); sink.GetData(result); diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp index 4fefafbd4..41abaeca9 100644 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -1,10 +1,7 @@ #include "duckdb/execution/operator/join/physical_range_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/unordered_map.hpp" @@ -14,15 +11,15 @@ #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/parallel/executor_task.hpp" - -#include +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { -PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, +PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ExecutionContext &context, GlobalSortedTable 
&global_table, const idx_t child) - : op(op), executor(context), has_null(0), count(0) { + : global_table(global_table), executor(context.client), has_null(0), count(0) { // Initialize order clause expression executor and key DataChunk + const auto &op = global_table.op; vector types; for (const auto &cond : op.conditions) { const auto &expr = child ? cond.right : cond.left; @@ -30,16 +27,19 @@ PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, co types.push_back(expr->return_type); } - auto &allocator = Allocator::Get(context); + auto &allocator = Allocator::Get(context.client); keys.Initialize(allocator, types); -} -void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); - } + local_sink = global_table.sort->GetLocalSinkState(context); + + // Only sort the primary key + types.resize(1); + const auto &payload_types = op.children[child].get().types; + types.insert(types.end(), payload_types.begin(), payload_types.end()); + sort_chunk.InitializeEmpty(types); +} +void PhysicalRangeJoin::LocalSortedTable::Sink(ExecutionContext &context, DataChunk &input) { // Obtain sorting columns keys.Reset(); executor.Execute(input, keys); @@ -47,121 +47,179 @@ void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState // Do not operate on primary key directly to avoid modifying the input chunk Vector primary = keys.data[0]; // Count the NULLs so we can exclude them later - has_null += MergeNulls(primary, op.conditions); + has_null += MergeNulls(primary, global_table.op.conditions); count += keys.size(); // Only sort the primary key - DataChunk join_head; - join_head.data.emplace_back(primary); - join_head.SetCardinality(keys.size()); + sort_chunk.data[0].Reference(primary); + for (column_t col_idx = 0; col_idx < input.ColumnCount(); ++col_idx) { + sort_chunk.data[col_idx + 1].Reference(input.data[col_idx]); + } + sort_chunk.SetCardinality(input); // Sink the data into the local sort state - local_sort_state.SinkChunk(join_head, input); + InterruptState interrupt; + OperatorSinkInput sink {*global_table.global_sink, *local_sink, interrupt}; + global_table.sort->Sink(context, sort_chunk, sink); } -PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout, const PhysicalOperator &op_p) - : op(op_p), global_sort_state(context, orders, payload_layout), has_null(0), count(0), memory_per_thread(0) { +PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &client, + const vector &order_bys, + const vector &payload_types, + const PhysicalRangeJoin &op) + : op(op), has_null(0), count(0), tasks_completed(0) { + // Set up the sort. We will materialize keys ourselves, so just set up references. 
+ vector orders; + vector input_types; + for (const auto &order_by : order_bys) { + auto order = order_by.Copy(); + const auto type = order.expression->return_type; + input_types.emplace_back(type); + order.expression = make_uniq(type, orders.size()); + orders.emplace_back(std::move(order)); + } + + vector projection_map; + for (const auto &type : payload_types) { + projection_map.emplace_back(input_types.size()); + input_types.emplace_back(type); + } + + sort = make_uniq(client, orders, input_types, projection_map); - // Set external (can be forced with the PRAGMA) - global_sort_state.external = ClientConfig::GetConfig(context).force_external; - memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); + global_sink = sort->GetGlobalSinkState(client); } -void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { - global_sort_state.AddLocalState(ltable.local_sort_state); +void PhysicalRangeJoin::GlobalSortedTable::Combine(ExecutionContext &context, LocalSortedTable <able) { + InterruptState interrupt; + OperatorSinkCombineInput combine {*global_sink, *ltable.local_sink, interrupt}; + sort->Combine(context, combine); has_null += ltable.has_null; count += ltable.count; } +void PhysicalRangeJoin::GlobalSortedTable::Finalize(ClientContext &client, InterruptState &interrupt) { + OperatorSinkFinalizeInput finalize {*global_sink, interrupt}; + sort->Finalize(client, finalize); +} + void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { found_match = make_unsafe_uniq_array_uninitialized(Count()); memset(found_match.get(), 0, sizeof(bool) * Count()); } +void PhysicalRangeJoin::GlobalSortedTable::MaterializeEmpty(ClientContext &client) { + D_ASSERT(!sorted); + sorted = make_uniq(client, *sort, false); +} + void PhysicalRangeJoin::GlobalSortedTable::Print() { - global_sort_state.Print(); + D_ASSERT(sorted); + auto &collection = *sorted->payload_data; + TupleDataScanState scanner; + collection.InitializeScan(scanner); + + DataChunk payload; + collection.InitializeScanChunk(scanner, payload); + + while (collection.Scan(scanner, payload)) { + payload.Print(); + } } -class RangeJoinMergeTask : public ExecutorTask { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeTask +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeTask : public ExecutorTask { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context, std::move(event_p), table.op), context(context), table(table) { + RangeJoinMaterializeTask(Pipeline &pipeline, shared_ptr event, ClientContext &client, + GlobalSortedTable &table, idx_t tasks_scheduled) + : ExecutorTask(client, std::move(event), table.op), pipeline(pipeline), table(table), + tasks_scheduled(tasks_scheduled) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize iejoin sorted and iterate until done - auto &global_sort_state = table.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); + ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); + auto &sort = *table.sort; + auto &sort_global = *table.global_source; + auto sort_local = sort.GetLocalSourceState(execution, sort_global); + InterruptState 
interrupt((weak_ptr(shared_from_this()))); + OperatorSourceInput input {sort_global, *sort_local, interrupt}; + sort.MaterializeSortedRun(execution, input); + if (++table.tasks_completed == tasks_scheduled) { + table.sorted = sort.GetSortedRun(sort_global); + if (!table.sorted) { + table.MaterializeEmpty(execution.client); + } + } + event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; } string TaskType() const override { - return "RangeJoinMergeTask"; + return "RangeJoinMaterializeTask"; } private: - ClientContext &context; + Pipeline &pipeline; GlobalSortedTable &table; + const idx_t tasks_scheduled; }; -class RangeJoinMergeEvent : public BasePipelineEvent { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeEvent +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeEvent : public BasePipelineEvent { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), table(table_p) { + RangeJoinMaterializeEvent(GlobalSortedTable &table, Pipeline &pipeline) + : BasePipelineEvent(pipeline), table(table) { } GlobalSortedTable &table; public: void Schedule() override { - auto &context = pipeline->GetClientContext(); + auto &client = pipeline->GetClientContext(); - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); + // Schedule as many tasks as the sort will allow + auto &ts = TaskScheduler::GetScheduler(client); auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> iejoin_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); + vector> tasks; + + auto &sort = *table.sort; + auto &global_sink = *table.global_sink; + table.global_source = sort.GetGlobalSourceState(client, global_sink); + const auto tasks_scheduled = MinValue(num_threads, table.global_source->MaxThreads()); + for (idx_t tnum = 0; tnum < tasks_scheduled; ++tnum) { + tasks.push_back( + make_uniq(*pipeline, shared_from_this(), client, table, tasks_scheduled)); } - SetTasks(std::move(iejoin_tasks)); - } - void FinishEvent() override { - auto &global_sort_state = table.global_sort_state; - - global_sort_state.CompleteMergeRound(true); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - table.ScheduleMergeTasks(*pipeline, *this); - } + SetTasks(std::move(tasks)); } }; -void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { - // Initialize global sort state for a round of merging - global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(*this, pipeline); - event.InsertEvent(std::move(new_event)); +void PhysicalRangeJoin::GlobalSortedTable::Materialize(Pipeline &pipeline, Event &event) { + // Schedule all the sorts for maximum thread utilisation + auto sort_event = make_shared_ptr(*this, pipeline); + event.InsertEvent(std::move(sort_event)); } -void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - ScheduleMergeTasks(pipeline, event); +void 
PhysicalRangeJoin::GlobalSortedTable::Materialize(ExecutionContext &context, InterruptState &interrupt) { + global_source = sort->GetGlobalSourceState(context.client, *global_sink); + auto local_source = sort->GetLocalSourceState(context, *global_source); + OperatorSourceInput source {*global_source, *local_source, interrupt}; + sort->MaterializeSortedRun(context, source); + sorted = sort->GetSortedRun(*global_source); + if (!sorted) { + MaterializeEmpty(context.client); } } @@ -336,56 +394,74 @@ void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const result.SetCardinality(chunk); } -BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols) { - // There should only be one sorted block if they have been sorted - D_ASSERT(state.sorted_blocks.size() == 1); - SBScanState read_state(state.buffer_manager, state); - read_state.sb = state.sorted_blocks[0].get(); - auto &sorted_data = *read_state.sb->payload_data; - - read_state.SetIndices(block_idx, 0); - read_state.PinData(sorted_data); - const auto data_ptr = read_state.DataPtr(sorted_data); - data_ptr_t heap_ptr = nullptr; - - // Set up a batch of pointers to scan data from - Vector addresses(LogicalType::POINTER, result_count); - auto data_pointers = FlatVector::GetData(addresses); - - // Set up the data pointers for the values that are actually referenced - const idx_t &row_width = sorted_data.layout.GetRowWidth(); - - auto prev_idx = result.get_index(0); - SelectionVector gsel(result_count); - idx_t addr_count = 0; - gsel.set_index(0, addr_count); - data_pointers[addr_count] = data_ptr + prev_idx * row_width; - for (idx_t i = 1; i < result_count; ++i) { - const auto row_idx = result.get_index(i); - if (row_idx != prev_idx) { - data_pointers[++addr_count] = data_ptr + row_idx * row_width; - prev_idx = row_idx; - } - gsel.set_index(i, addr_count); +template +static void TemplatedSliceSortedPayload(DataChunk &chunk, const SortedRun &sorted_run, + ExternalBlockIteratorState &state, Vector &sort_key_pointers, + SortedRunScanState &scan_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = state.GetIndex(chunk_idx, result.get_index(i)); + sort_keys[i] = &itr[idx]; } - ++addr_count; - // Unswizzle the offsets back to pointers (if needed) - if (!sorted_data.layout.AllConstant() && state.external) { - heap_ptr = read_state.payload_heap_handle.Ptr(); - } + // Scan + chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, chunk); +} - // Deserialize the payload data - auto sel = FlatVector::IncrementalSelectionVector(); - for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { - auto &col = payload.data[left_cols + col_no]; - RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); - col.Slice(gsel, result_count); +void PhysicalRangeJoin::SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, + ExternalBlockIteratorState &state, TupleDataChunkState &chunk_state, + const idx_t chunk_idx, SelectionVector &result, const idx_t result_count, + SortedRunScanState &scan_state) { + auto &sorted = *table.sorted; + auto 
&sort_keys = chunk_state.row_locations; + const auto sort_key_type = table.GetSortKeyType(); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); } - - return std::move(read_state.payload_heap_handle); } idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, diff --git a/src/duckdb/src/execution/operator/order/physical_order.cpp b/src/duckdb/src/execution/operator/order/physical_order.cpp index de4e6acaa..90791392a 100644 --- a/src/duckdb/src/execution/operator/order/physical_order.cpp +++ b/src/duckdb/src/execution/operator/order/physical_order.cpp @@ -114,7 +114,8 @@ unique_ptr PhysicalOrder::GetGlobalSourceState(ClientContext return make_uniq(context, sink_state->Cast()); } -SourceResultType PhysicalOrder::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalOrder::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); OperatorSourceInput sort_input {*gstate.state, *lstate.state, input.interrupt_state}; diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp index ec082601c..2620729db 100644 --- a/src/duckdb/src/execution/operator/order/physical_top_n.cpp +++ b/src/duckdb/src/execution/operator/order/physical_top_n.cpp @@ -1,6 +1,7 @@ #include "duckdb/execution/operator/order/physical_top_n.hpp" #include "duckdb/common/assert.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/create_sort_key.hpp" #include "duckdb/storage/data_table.hpp" @@ -85,7 +86,8 @@ class TopNHeap { Allocator &allocator; BufferManager &buffer_manager; - unsafe_vector heap; + ArenaAllocator arena_allocator; + unsafe_arena_vector heap; const vector &payload_types; const vector &orders; vector modifiers; @@ -162,10 
+164,11 @@ class TopNHeap { //===--------------------------------------------------------------------===// TopNHeap::TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types_p, const vector &orders_p, idx_t limit, idx_t offset) - : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), payload_types(payload_types_p), - orders(orders_p), limit(limit), offset(offset), heap_size(limit + offset), executor(context), - sort_key_heap(allocator), matching_sel(STANDARD_VECTOR_SIZE), final_sel(STANDARD_VECTOR_SIZE), - true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), new_remaining_sel(STANDARD_VECTOR_SIZE) { + : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), arena_allocator(allocator), + heap(arena_allocator), payload_types(payload_types_p), orders(orders_p), limit(limit), offset(offset), + heap_size(limit + offset), executor(context), sort_key_heap(allocator), matching_sel(STANDARD_VECTOR_SIZE), + final_sel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE), + new_remaining_sel(STANDARD_VECTOR_SIZE) { // initialize the executor and the sort_chunk vector sort_types; for (auto &order : orders) { @@ -575,7 +578,8 @@ unique_ptr PhysicalTopN::GetLocalSourceState(ExecutionContext return make_uniq(); } -SourceResultType PhysicalTopN::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalTopN::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { if (limit == 0) { return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp index 31e73b41f..d9dddd069 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -644,8 +644,8 @@ unique_ptr PhysicalBatchCopyToFile::GetGlobalSinkState(ClientCo //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalBatchCopyToFile::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &g = sink_state->Cast(); auto fp = use_tmp_file ? 
PhysicalCopyToFile::GetNonTmpFile(context.client, file_path) : file_path; switch (return_type) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp index 95b519d4d..9fb8be6d8 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -90,7 +90,7 @@ class CollectionMerger { auto &collection = data_table.GetOptimisticCollection(context, collection_indexes[i]); TableScanState scan_state; scan_state.Initialize(column_ids); - collection.collection->InitializeScan(scan_state.local_state, column_ids, nullptr); + collection.collection->InitializeScan(context, scan_state.local_state, column_ids, nullptr); while (true) { scan_chunk.Reset(); @@ -194,7 +194,10 @@ class BatchInsertLocalState : public LocalSinkState { void CreateNewCollection(ClientContext &context, DuckTableEntry &table_entry, const vector &insert_types) { - auto collection = OptimisticDataWriter::CreateCollection(table_entry.GetStorage(), insert_types); + if (!optimistic_writer) { + optimistic_writer = make_uniq(context, table_entry.GetStorage()); + } + auto collection = optimistic_writer->CreateCollection(table_entry.GetStorage(), insert_types); auto &row_collection = *collection->collection; row_collection.InitializeEmpty(); row_collection.InitializeAppend(current_append_state); @@ -526,9 +529,6 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &i lock_guard l(gstate.lock); // no collection yet: create a new one lstate.CreateNewCollection(context.client, table, insert_types); - if (!lstate.optimistic_writer) { - lstate.optimistic_writer = make_uniq(context.client, table.GetStorage()); - } } if (lstate.current_index != batch_index) { @@ -689,8 +689,8 @@ SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalBatchInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalBatchInsert::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &insert_gstate = sink_state->Cast(); chunk.SetCardinality(1); diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp index 9b16d476c..f0c5952f3 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp @@ -28,8 +28,8 @@ PhysicalCopyDatabase::~PhysicalCopyDatabase() { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCopyDatabase::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCopyDatabase::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = Catalog::GetCatalog(context.client, info->target_database); for (auto &create_info : info->entries) { switch (create_info->type) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp index 
5958a9885..740819261 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -470,7 +470,6 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_file_path) { auto &fs = FileSystem::GetFileSystem(context); auto file_path = GetNonTmpFile(context, tmp_file_path); - fs.TryRemoveFile(file_path); fs.MoveFile(tmp_file_path, file_path); } @@ -704,8 +703,8 @@ void PhysicalCopyToFile::ReturnStatistics(DataChunk &chunk, idx_t row_idx, CopyT chunk.SetValue(5, row_idx, info.partition_keys); } -SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCopyToFile::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &g = sink_state->Cast(); if (return_type == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS) { auto &source_state = input.global_state.Cast(); diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp index 13458c923..eaef83502 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp @@ -27,7 +27,6 @@ class DeleteGlobalState : public GlobalSinkState { explicit DeleteGlobalState(ClientContext &context, const vector &return_types, TableCatalogEntry &table, const vector> &bound_constraints) : deleted_count(0), return_collection(context, return_types), has_unique_indexes(false) { - // We need to append deletes to the local delete-ART. auto &storage = table.GetStorage(); if (storage.HasUniqueIndexes()) { @@ -213,8 +212,8 @@ unique_ptr PhysicalDelete::GetGlobalSourceState(ClientContext return make_uniq(*this); } -SourceResultType PhysicalDelete::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalDelete::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &state = input.global_state.Cast(); auto &g = sink_state->Cast(); if (!return_chunk) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_export.cpp b/src/duckdb/src/execution/operator/persistent/physical_export.cpp index 67e6ab40c..343ce5a47 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_export.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_export.cpp @@ -205,8 +205,8 @@ catalog_entry_vector_t PhysicalExport::GetNaiveExportOrder(ClientContext &contex return catalog_entries; } -SourceResultType PhysicalExport::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalExport::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &state = input.global_state.Cast(); if (state.finished) { return SourceResultType::FINISHED; diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp index 97c31c4ba..b0b644ca6 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -36,7 +36,6 @@ PhysicalInsert::PhysicalInsert(PhysicalPlan &physical_plan, vector 
set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), set_types(std::move(set_types)), on_conflict_condition(std::move(on_conflict_condition_p)), do_update_condition(std::move(do_update_condition_p)), conflict_target(std::move(conflict_target_p)), update_is_del_and_insert(update_is_del_and_insert) { - if (action_type == OnConflictAction::THROW) { return; } @@ -82,7 +81,6 @@ InsertGlobalState::InsertGlobalState(ClientContext &context, const vector &types, const vector> &bound_constraints) : collection_index(DConstants::INVALID_INDEX), bound_constraints(bound_constraints) { - auto &allocator = Allocator::Get(context); update_chunk.Initialize(allocator, types); append_chunk.Initialize(allocator, types); @@ -189,7 +187,6 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, DataChunk &update_chunk, const PhysicalInsert &op) { - auto &do_update_condition = op.do_update_condition; auto &set_types = op.set_types; auto &set_expressions = op.set_expressions; @@ -651,14 +648,14 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &insert D_ASSERT(!return_chunk); auto &data_table = gstate.table.GetStorage(); if (!lstate.collection_index.IsValid()) { + lock_guard l(gstate.lock); + lstate.optimistic_writer = make_uniq(context.client, data_table); // Create the local row group collection. - auto optimistic_collection = OptimisticDataWriter::CreateCollection(storage, insert_types); + auto optimistic_collection = lstate.optimistic_writer->CreateCollection(storage, insert_types); auto &collection = *optimistic_collection->collection; collection.InitializeEmpty(); collection.InitializeAppend(lstate.local_append_state); - lock_guard l(gstate.lock); - lstate.optimistic_writer = make_uniq(context.client, data_table); lstate.collection_index = data_table.CreateOptimisticCollection(context.client, std::move(optimistic_collection)); } @@ -748,8 +745,8 @@ unique_ptr PhysicalInsert::GetGlobalSourceState(ClientContext return make_uniq(*this); } -SourceResultType PhysicalInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalInsert::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &state = input.global_state.Cast(); auto &insert_gstate = sink_state->Cast(); if (!return_chunk) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp index 04a5f3dca..672a9b861 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp @@ -10,7 +10,6 @@ PhysicalMergeInto::PhysicalMergeInto(PhysicalPlan &physical_plan, vector ranges; for (auto &entry : actions_p) { MergeActionRange range; @@ -456,8 +455,8 @@ unique_ptr PhysicalMergeInto::GetLocalSourceState(ExecutionCon return make_uniq(context, *this, gstate.Cast()); } -SourceResultType PhysicalMergeInto::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalMergeInto::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &g = sink_state->Cast(); if (!return_chunk) { chunk.SetCardinality(1); diff --git 
a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp index f96dba699..3e698cf02 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -25,7 +25,6 @@ PhysicalUpdate::PhysicalUpdate(PhysicalPlan &physical_plan, vector tableref(tableref), table(table), columns(std::move(columns)), expressions(std::move(expressions)), bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints)), return_chunk(return_chunk), index_update(false) { - auto &indexes = table.GetDataTableInfo().get()->GetIndexes(); auto index_columns = indexes.GetRequiredColumns(); @@ -67,7 +66,6 @@ class UpdateLocalState : public LocalSinkState { const vector &table_types, const vector> &bound_defaults, const vector> &bound_constraints) : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { - // Initialize the update chunk. auto &allocator = Allocator::Get(context); vector update_types; @@ -244,8 +242,8 @@ unique_ptr PhysicalUpdate::GetGlobalSourceState(ClientContext return make_uniq(*this); } -SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalUpdate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &state = input.global_state.Cast(); auto &g = sink_state->Cast(); if (!return_chunk) { diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp index ffd9ec565..21a111f68 100644 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -9,7 +9,6 @@ PhysicalPivot::PhysicalPivot(PhysicalPlan &physical_plan, vector ty BoundPivotInfo bound_pivot_p) : PhysicalOperator(physical_plan, PhysicalOperatorType::PIVOT, std::move(types_p), child.estimated_cardinality), bound_pivot(std::move(bound_pivot_p)) { - children.push_back(child); for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { auto entry = pivot_map.find(bound_pivot.pivot_values[p]); @@ -22,12 +21,12 @@ PhysicalPivot::PhysicalPivot(PhysicalPlan &physical_plan, vector ty for (auto &aggr_expr : bound_pivot.aggregates) { auto &aggr = aggr_expr->Cast(); // for each aggregate, initialize an empty aggregate state and finalize it immediately - auto state = make_unsafe_uniq_array(aggr.function.state_size(aggr.function)); - aggr.function.initialize(aggr.function, state.get()); + auto state = make_unsafe_uniq_array(aggr.function.GetStateSizeCallback()(aggr.function)); + aggr.function.GetStateInitCallback()(aggr.function, state.get()); Vector state_vector(Value::POINTER(CastPointerToValue(state.get()))); Vector result_vector(aggr_expr->return_type); AggregateInputData aggr_input_data(aggr.bind_info.get(), physical_plan.ArenaRef()); - aggr.function.finalize(state_vector, aggr_input_data, result_vector, 1, 0); + aggr.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, result_vector, 1, 0); empty_aggregates.push_back(result_vector.GetValue(0)); } } diff --git a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp index f464bfb18..de0e7ff99 100644 --- 
a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp @@ -19,6 +19,14 @@ class TableInOutGlobalState : public GlobalOperatorState { TableInOutGlobalState() { } + idx_t MaxThreads(idx_t source_max_threads) override { + // If no state assume maximum parallelism as the source. + if (!global_state) { + return source_max_threads; + } + return global_state->MaxThreads(); + } + unique_ptr global_state; }; diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp index 62164d95b..e1ce4bb05 100644 --- a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp @@ -13,7 +13,6 @@ class UnnestOperatorState : public OperatorState { public: UnnestOperatorState(ClientContext &context, const vector> &select_list) : current_row(0), list_position(0), first_fetch(true), input_sel(STANDARD_VECTOR_SIZE), executor(context) { - // for each UNNEST in the select_list, we add the child expression to the expression executor // and set the return type in the list_data chunk, which will contain the evaluated expression results vector list_data_types; @@ -139,7 +138,6 @@ OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, Da OperatorState &state_p, const vector> &select_list, bool include_input) { - auto &state = state_p.Cast(); do { diff --git a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp index f842fca30..5e15f1342 100644 --- a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp @@ -53,8 +53,8 @@ unique_ptr PhysicalColumnDataScan::GetLocalSourceState(Executi return make_uniq(); } -SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalColumnDataScan::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); collection->Scan(gstate.global_scan_state, lstate.local_scan_state, chunk); diff --git a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp index 1a620803b..bb1948193 100644 --- a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp @@ -2,8 +2,8 @@ namespace duckdb { -SourceResultType PhysicalDummyScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalDummyScan::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { // return a single row on the first call to the dummy scan chunk.SetCardinality(1); diff --git a/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp b/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp index 2e7d006bf..ee54b460a 100644 --- a/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp @@ -2,8 +2,8 @@ namespace duckdb { -SourceResultType PhysicalEmptyResult::GetData(ExecutionContext &context, DataChunk &chunk, - 
OperatorSourceInput &input) const { +SourceResultType PhysicalEmptyResult::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp index bff24d785..e618a7e24 100644 --- a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp @@ -14,7 +14,6 @@ PhysicalPositionalScan::PhysicalPositionalScan(PhysicalPlan &physical_plan, vect PhysicalOperator &left, PhysicalOperator &right) : PhysicalOperator(physical_plan, PhysicalOperatorType::POSITIONAL_SCAN, std::move(types), MaxValue(left.estimated_cardinality, right.estimated_cardinality)) { - // Manage the children ourselves if (left.type == PhysicalOperatorType::TABLE_SCAN) { child_tables.emplace_back(left); @@ -67,10 +66,14 @@ class PositionalTableScanner { InterruptState interrupt_state; OperatorSourceInput source_input {global_state, *local_state, interrupt_state}; - auto source_result = table.GetData(context, source, source_input); - if (source_result == SourceResultType::BLOCKED) { - throw NotImplementedException( - "Unexpected interrupt from table Source in PositionalTableScanner refill"); + auto source_result = SourceResultType::HAVE_MORE_OUTPUT; + while (source_result == SourceResultType::HAVE_MORE_OUTPUT && source.size() == 0) { + // TODO: this could as well just be propagated further, but for now iterating it is + source_result = table.GetData(context, source, source_input); + if (source_result == SourceResultType::BLOCKED) { + throw NotImplementedException( + "Unexpected interrupt from table Source in PositionalTableScanner refill"); + } } } source_offset = 0; @@ -154,8 +157,8 @@ unique_ptr PhysicalPositionalScan::GetGlobalSourceState(Clien return make_uniq(context, *this); } -SourceResultType PhysicalPositionalScan::GetData(ExecutionContext &context, DataChunk &output, - OperatorSourceInput &input) const { +SourceResultType PhysicalPositionalScan::GetDataInternal(ExecutionContext &context, DataChunk &output, + OperatorSourceInput &input) const { auto &lstate = input.local_state.Cast(); // Find the longest source block diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index e9f66bea4..aa3a859c9 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -4,6 +4,9 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/transaction/transaction.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" +#include "duckdb/main/settings.hpp" #include @@ -16,6 +19,7 @@ PhysicalTableScan::PhysicalTableScan(PhysicalPlan &physical_plan, vector parameters_p, virtual_column_map_t virtual_columns_p) : PhysicalOperator(physical_plan, PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), + function(std::move(function_p)), bind_data(std::move(bind_data_p)), returned_types(std::move(returned_types_p)), column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), table_filters(std::move(table_filters_p)), extra_info(std::move(extra_info)), parameters(std::move(parameters_p)), @@ -25,6 
+29,9 @@ PhysicalTableScan::PhysicalTableScan(PhysicalPlan &physical_plan, vector(context); + if (op.dynamic_filters && op.dynamic_filters->HasFilters()) { table_filters = op.dynamic_filters->GetFinalTableFilters(op, op.table_filters.get()); } @@ -56,6 +63,7 @@ class TableScanGlobalSourceState : public GlobalSourceState { } idx_t max_threads = 0; + PhysicalTableScanExecutionStrategy physical_table_scan_execution_strategy; unique_ptr global_state; bool in_out_final = false; DataChunk input_chunk; @@ -93,8 +101,63 @@ unique_ptr PhysicalTableScan::GetGlobalSourceState(ClientCont return make_uniq(context, *this); } -SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +static void ValidateAsyncStrategyResult(const PhysicalTableScanExecutionStrategy &strategy, + const AsyncResultsExecutionMode &execution_mode_pre, + const AsyncResultsExecutionMode &execution_mode_post, + const AsyncResultType &result_pre, const AsyncResultType &result_post, + const idx_t output_chunk_size) { + auto execution_mode_pre_computed = AsyncResult::ConvertToAsyncResultExecutionMode(strategy); + if (execution_mode_pre_computed != execution_mode_pre) { + throw InternalException("ValidateAsyncStrategyResult: invalid conversion PhysicalTableScanExecutionStrategy to " + "AsyncResultsExecutionMode, from '%s', to '%s'", + EnumUtil::ToChars(strategy), EnumUtil::ToChars(execution_mode_pre)); + } + + if (execution_mode_pre != execution_mode_post) { + throw InternalException("ValidateAsyncStrategyResult: results_execution_mode changed within table API's " + "`function` call, before '%s', after '%s'", + EnumUtil::ToChars(execution_mode_pre), EnumUtil::ToChars(execution_mode_post)); + } + if (result_pre != AsyncResultType::IMPLICIT) { + throw InternalException("ValidateAsyncStrategyResult: async_result is supposed to be IMPLICIT, was '%s', " + "before table API's `function` call", + EnumUtil::ToChars(result_pre)); + } + switch (strategy) { + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS: + // This is a funny one, expected to throw on non-trivial workflows in this function + case PhysicalTableScanExecutionStrategy::SYNCHRONOUS: + switch (result_post) { + case AsyncResultType::INVALID: + throw InternalException("ValidateAsyncStrategyResult: found INVALID"); + case AsyncResultType::BLOCKED: + throw InternalException("ValidateAsyncStrategyResult: found BLOCKED"); + case AsyncResultType::FINISHED: + if (output_chunk_size > 0) { + throw InternalException("ValidateAsyncStrategyResult: found FINISHED with non-empty chunk"); + } + break; + case AsyncResultType::HAVE_MORE_OUTPUT: + if (output_chunk_size == 0) { + throw InternalException("ValidateAsyncStrategyResult: found HAVE_MORE_OUTPUT with empty chunk"); + } + break; + case AsyncResultType::IMPLICIT: + break; + } + break; + default: + if (result_post == AsyncResultType::BLOCKED) { + if (output_chunk_size > 0) { + throw InternalException("ValidateAsyncStrategyResult: found BLOCKED with non-empty chunk"); + } + } + break; + } +} + +SourceResultType PhysicalTableScan::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { D_ASSERT(!column_ids.empty()); auto &g_state = input.global_state.Cast(); auto &l_state = input.local_state.Cast(); @@ -102,15 +165,55 @@ SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk TableFunctionInput data(bind_data.get(), l_state.local_state.get(), g_state.global_state.get()); if 
(function.function) { + data.async_result = AsyncResultType::IMPLICIT; + + const auto initial_async_result = data.async_result.GetResultType(); + const auto execution_strategy = g_state.physical_table_scan_execution_strategy; + const auto input_execution_mode = AsyncResult::ConvertToAsyncResultExecutionMode(execution_strategy); + data.results_execution_mode = input_execution_mode; + + // Actually call the function function.function(context.client, data, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + + const auto output_async_result = data.async_result.GetResultType(); + + // Compare and check whether state before and after function.function call is compatible, will throw in case of + // inconsistencies + ValidateAsyncStrategyResult(execution_strategy, input_execution_mode, data.results_execution_mode, + initial_async_result, output_async_result, chunk.size()); + + // Handle results + switch (output_async_result) { + case AsyncResultType::BLOCKED: { + D_ASSERT(data.async_result.HasTasks()); + auto guard = g_state.Lock(); + if (g_state.CanBlock(guard)) { + data.async_result.ScheduleTasks(input.interrupt_state, context.pipeline->executor); + return SourceResultType::BLOCKED; + } + return SourceResultType::FINISHED; + } + case AsyncResultType::IMPLICIT: + if (chunk.size() > 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + return SourceResultType::FINISHED; + case AsyncResultType::FINISHED: + return SourceResultType::FINISHED; + case AsyncResultType::HAVE_MORE_OUTPUT: + return SourceResultType::HAVE_MORE_OUTPUT; + default: + throw InternalException( + "PhysicalTableScan::GetData call of function.function returned unexpected return '%s'", + EnumUtil::ToChars(data.async_result.GetResultType())); + } + throw InternalException("PhysicalTableScan::GetData hasn't handled a function.function return"); } if (g_state.in_out_final) { function.in_out_function_final(context, data, chunk); } switch (function.in_out_function(context, data, g_state.input_chunk, chunk)) { - case OperatorResultType::BLOCKED: { auto guard = g_state.Lock(); return g_state.BlockSource(guard, input.interrupt_state); @@ -188,6 +291,33 @@ void AddProjectionNames(const ColumnIndex &index, const string &name, const Logi } } +static string GetFilterInfo(const PhysicalTableScan *scan, const unique_ptr &filter_set) { + string filters_info; + bool first_item = true; + for (auto &f : filter_set->filters) { + auto &column_index = f.first; + auto &filter = f.second; + if (column_index < scan->names.size()) { + if (!first_item) { + filters_info += "\n"; + } + first_item = false; + + const auto col_id = scan->column_ids[column_index].GetPrimaryIndex(); + if (IsVirtualColumn(col_id)) { + auto entry = scan->virtual_columns.find(col_id); + if (entry == scan->virtual_columns.end()) { + throw InternalException("Virtual column not found"); + } + filters_info += filter->ToString(entry->second.name); + } else { + filters_info += filter->ToString(scan->names[col_id]); + } + } + } + return filters_info; +} + InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { InsertionOrderPreservingMap result; if (function.to_string) { @@ -214,31 +344,13 @@ InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { result["Projections"] = projections; } if (function.filter_pushdown && table_filters) { - string filters_info; - bool first_item = true; - for (auto &f : table_filters->filters) { - auto &column_index = f.first; - auto &filter = f.second; - if (column_index <
names.size()) { - if (!first_item) { - filters_info += "\n"; - } - first_item = false; - - const auto col_id = column_ids[column_index].GetPrimaryIndex(); - if (IsVirtualColumn(col_id)) { - auto entry = virtual_columns.find(col_id); - if (entry == virtual_columns.end()) { - throw InternalException("Virtual column not found"); - } - filters_info += filter->ToString(entry->second.name); - } else { - filters_info += filter->ToString(names[col_id]); - } - } - } - result["Filters"] = filters_info; + result["Filters"] = GetFilterInfo(this, table_filters); } + + if (function.filter_pushdown && dynamic_filters && dynamic_filters->HasFilters()) { + result["Dynamic Filters"] = GetFilterInfo(this, dynamic_filters->GetFinalTableFilters(*this, nullptr)); + } + if (extra_info.sample_options) { result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + "%"; } @@ -259,7 +371,7 @@ bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const { return false; } auto &other = other_p.Cast(); - if (function.function != other.function.function) { + if (function != other.function) { return false; } if (column_ids != other.column_ids) { diff --git a/src/duckdb/src/execution/operator/schema/physical_alter.cpp b/src/duckdb/src/execution/operator/schema/physical_alter.cpp index 957a762c4..b41a15a8b 100644 --- a/src/duckdb/src/execution/operator/schema/physical_alter.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_alter.cpp @@ -9,7 +9,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalAlter::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalAlter::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { if (info->type == AlterType::ALTER_DATABASE) { auto &db_info = info->Cast(); auto &db_manager = DatabaseManager::Get(context.client); diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp index 48e687703..9ec3c152d 100644 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -13,8 +13,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalAttach::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { // parse the options auto &config = DBConfig::GetConfig(context.client); // construct the options @@ -40,7 +40,6 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c if (existing_db) { if ((existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_WRITE) || (!existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_ONLY)) { - auto existing_mode = existing_db->IsReadOnly() ? 
AccessMode::READ_ONLY : AccessMode::READ_WRITE; auto existing_mode_str = EnumUtil::ToString(existing_mode); auto attached_mode = EnumUtil::ToString(options.access_mode); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp index d21b7bcf1..eee2d4a8d 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp @@ -23,7 +23,6 @@ PhysicalCreateARTIndex::PhysicalCreateARTIndex(PhysicalPlan &physical_plan, Logi : PhysicalOperator(physical_plan, PhysicalOperatorType::CREATE_INDEX, op.types, estimated_cardinality), table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), sorted(sorted), alter_table_info(std::move(alter_table_info)) { - // Convert the logical column ids to physical column ids. for (auto &column_id : column_ids) { storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); @@ -85,7 +84,6 @@ unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionCo } SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) const { - auto &l_state = input.local_state.Cast(); auto row_count = l_state.key_chunk.size(); auto &art = l_state.local_index->Cast(); @@ -105,7 +103,6 @@ SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) co } SinkResultType PhysicalCreateARTIndex::SinkSorted(OperatorSinkInput &input) const { - auto &l_state = input.local_state.Cast(); auto &storage = table.GetStorage(); auto &l_index = l_state.local_index; @@ -172,7 +169,7 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve // Vacuum excess memory and verify. 
state.global_index->Vacuum(); - D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); + state.global_index->Verify(); state.global_index->VerifyAllocations(); auto &storage = table.GetStorage(); @@ -223,8 +220,8 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateARTIndex::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateARTIndex::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/schema/physical_create_function.cpp b/src/duckdb/src/execution/operator/schema/physical_create_function.cpp index 2521b2082..cbc9d6df1 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_function.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_function.cpp @@ -8,8 +8,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateFunction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateFunction::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = Catalog::GetCatalog(context.client, info->catalog); catalog.CreateFunction(context.client, *info); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp b/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp index b0b031390..50fb4b24b 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp @@ -7,8 +7,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSchema::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateSchema::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = Catalog::GetCatalog(context.client, info->catalog); if (catalog.IsSystemCatalog()) { throw BinderException("Cannot create schema in system catalog"); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp b/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp index 80c4a2ffa..d0a3cbcb5 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp @@ -6,8 +6,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSequence::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateSequence::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = Catalog::GetCatalog(context.client, info->catalog); catalog.CreateSequence(context.client, *info); diff --git 
a/src/duckdb/src/execution/operator/schema/physical_create_table.cpp b/src/duckdb/src/execution/operator/schema/physical_create_table.cpp index 6b5a9e6b2..7e17518a3 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_table.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_table.cpp @@ -16,8 +16,8 @@ PhysicalCreateTable::PhysicalCreateTable(PhysicalPlan &physical_plan, LogicalOpe //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateTable::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateTable::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = schema.catalog; catalog.CreateTable(catalog.GetCatalogTransaction(context.client), schema, *info); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_type.cpp b/src/duckdb/src/execution/operator/schema/physical_create_type.cpp index a98941bc4..8005b4ebc 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_type.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_type.cpp @@ -70,8 +70,8 @@ SinkResultType PhysicalCreateType::Sink(ExecutionContext &context, DataChunk &ch //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateType::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateType::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { if (IsSink()) { D_ASSERT(info->type == LogicalType::INVALID); auto &g_sink_state = sink_state->Cast(); diff --git a/src/duckdb/src/execution/operator/schema/physical_create_view.cpp b/src/duckdb/src/execution/operator/schema/physical_create_view.cpp index 948adad14..c3d4ba21f 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_view.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_view.cpp @@ -6,8 +6,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateView::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalCreateView::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &catalog = Catalog::GetCatalog(context.client, info->catalog); catalog.CreateView(context.client, *info); diff --git a/src/duckdb/src/execution/operator/schema/physical_detach.cpp b/src/duckdb/src/execution/operator/schema/physical_detach.cpp index 480890c3a..1a2ff0700 100644 --- a/src/duckdb/src/execution/operator/schema/physical_detach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_detach.cpp @@ -11,8 +11,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalDetach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalDetach::GetDataInternal(ExecutionContext &context, 
DataChunk &chunk, + OperatorSourceInput &input) const { auto &db_manager = DatabaseManager::Get(context.client); db_manager.DetachDatabase(context.client, info->name, info->if_not_found); diff --git a/src/duckdb/src/execution/operator/schema/physical_drop.cpp b/src/duckdb/src/execution/operator/schema/physical_drop.cpp index 7c9cbf933..ff780278e 100644 --- a/src/duckdb/src/execution/operator/schema/physical_drop.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_drop.cpp @@ -12,7 +12,8 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalDrop::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { +SourceResultType PhysicalDrop::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { switch (info->type) { case CatalogType::PREPARED_STATEMENT: { // DEALLOCATE silently ignores errors diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp index 79420e902..1fd840709 100644 --- a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp @@ -34,7 +34,6 @@ class RecursiveCTEState : public GlobalSinkState { public: explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { - vector payload_aggregates_ptr; for (idx_t i = 0; i < op.payload_aggregates.size(); i++) { auto &dat = op.payload_aggregates[i]; @@ -122,8 +121,8 @@ SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk & //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { +SourceResultType PhysicalRecursiveCTE::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); if (!gstate.initialized) { if (!using_key) { diff --git a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp index 48f05f88a..30f66c0f9 100644 --- a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp @@ -290,7 +290,7 @@ void PerfectAggregateHashTable::Destroy() { // check if there is any destructor to call bool has_destructor = false; for (auto &aggr : layout_ptr->GetAggregates()) { - if (aggr.function.destructor) { + if (aggr.function.HasStateDestructorCallback()) { has_destructor = true; } } diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp index ad51afa31..f10cc6d39 100644 --- a/src/duckdb/src/execution/physical_operator.cpp +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -125,7 +125,12 @@ unique_ptr PhysicalOperator::GetGlobalSourceState(ClientConte // LCOV_EXCL_START SourceResultType PhysicalOperator::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - throw InternalException("Calling GetData on a node that is not a source!"); + return GetDataInternal(context, chunk, input); +} + 
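// [Editor's note: illustrative sketch, not part of the upstream diff.]
// The hunk above turns PhysicalOperator::GetData into a thin wrapper around a new
// GetDataInternal hook, and the operator subclasses earlier in this patch rename their
// GetData overrides to GetDataInternal to match. Assuming the hook remains the virtual
// customization point (its declaration is not shown in this diff), the resulting shape is:
//
//   // base class: single public entry point; shared bookkeeping can later be added here once
//   SourceResultType PhysicalOperator::GetData(ExecutionContext &context, DataChunk &chunk,
//                                              OperatorSourceInput &input) const {
//       return GetDataInternal(context, chunk, input);
//   }
//
//   // sources override the hook instead of GetData, e.g. (return value assumed FINISHED,
//   // as in the PhysicalCreateARTIndex hunk above):
//   SourceResultType PhysicalCreateView::GetDataInternal(ExecutionContext &context, DataChunk &chunk,
//                                                        OperatorSourceInput &input) const {
//       auto &catalog = Catalog::GetCatalog(context.client, info->catalog);
//       catalog.CreateView(context.client, *info);
//       return SourceResultType::FINISHED;
//   }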
+SourceResultType PhysicalOperator::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + throw InternalException("Calling GetDataInternal on a node that is not a source!"); } OperatorPartitionData PhysicalOperator::GetPartitionData(ExecutionContext &context, DataChunk &chunk, @@ -301,7 +306,6 @@ bool CachingPhysicalOperator::CanCacheType(const LogicalType &type) { CachingPhysicalOperator::CachingPhysicalOperator(PhysicalPlan &physical_plan, PhysicalOperatorType type, vector types_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, type, std::move(types_p), estimated_cardinality) { - caching_supported = true; for (auto &col_type : types) { if (!CanCacheType(col_type)) { diff --git a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp index fcb2aaef5..8bd9837bd 100644 --- a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp @@ -224,7 +224,7 @@ static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate } for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct() || !aggregate.function.combine) { + if (aggregate.IsDistinct() || !aggregate.function.HasStateCombineCallback()) { // distinct aggregates are not supported in perfect hash aggregates return false; } @@ -236,12 +236,12 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { D_ASSERT(op.children.size() == 1); reference plan = CreatePlan(*op.children[0]); - plan = ExtractAggregateExpressions(plan, op.expressions, op.groups); + plan = ExtractAggregateExpressions(plan, op.expressions, op.groups, op.grouping_sets); bool can_use_simple_aggregation = true; for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if (!aggregate.function.simple_update) { + if (!aggregate.function.HasStateSimpleUpdateCallback()) { // unsupported aggregate for simple aggregation: use hash aggregation can_use_simple_aggregation = false; break; @@ -305,7 +305,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { PhysicalOperator &PhysicalPlanGenerator::ExtractAggregateExpressions(PhysicalOperator &child, vector> &aggregates, - vector> &groups) { + vector> &groups, + optional_ptr> grouping_sets) { vector> expressions; vector types; @@ -314,7 +315,7 @@ PhysicalOperator &PhysicalPlanGenerator::ExtractAggregateExpressions(PhysicalOpe auto &bound_aggr = aggr->Cast(); if (bound_aggr.order_bys) { // sorted aggregate! 
- FunctionBinder::BindSortedAggregate(context, bound_aggr, groups); + FunctionBinder::BindSortedAggregate(context, bound_aggr, groups, grouping_sets); } } for (auto &group : groups) { diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index 5759583c5..3d84506fa 100644 --- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -13,13 +13,13 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { optional_ptr PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOperator &probe, PhysicalOperator &build) { - // Plan a inverse nested loop join, then aggregate the values to choose the optimal match for each probe row. // Use a row number primary key to handle duplicate probe values. // aggregate the fields to produce at most one match per probe row, @@ -27,7 +27,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera // // ∏ * \ pk // | - // Γ pk;first(P),arg_xxx(B,inequality) + // Γ pk;first(P),arg_xxx_null(B,inequality) // | // ∏ *,inequality // | @@ -43,10 +43,9 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera const auto &probe_types = op.children[0]->types; join_op.types.insert(join_op.types.end(), probe_types.begin(), probe_types.end()); - // TODO: We can't handle predicates right now because we would have to remap column references. - if (op.predicate) { - return nullptr; - } + // Project pk + LogicalType pk_type = LogicalType::BIGINT; + join_op.types.emplace_back(pk_type); // Fill in the projection maps to simplify the code below // Since NLJ doesn't support projection, but ASOF does, @@ -65,9 +64,25 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera } } - // Project pk - LogicalType pk_type = LogicalType::BIGINT; - join_op.types.emplace_back(pk_type); + // Remap predicate column references. 
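// [Editor's note: descriptive comment, not part of the upstream diff.]
// The added block below supersedes the early-return removed further up ("We can't handle
// predicates right now because we would have to remap column references"). It copies
// op.predicate into join_op.predicate and walks it with ExpressionIterator::EnumerateExpression,
// rewriting every BOUND_REF index through swap_projection_map (assembled from the right and
// left projection maps) so the references line up with the column order produced by the
// inverse nested-loop plan built by this function. With that remapping in place, ASOF joins
// carrying an extra predicate are routed to this loop-join fallback instead of being
// rejected (see the PlanAsOfJoin hunk further below).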
+ if (op.predicate) { + vector swap_projection_map; + const auto rhs_width = op.children[1]->types.size(); + for (const auto &l : join_op.right_projection_map) { + swap_projection_map.emplace_back(l + rhs_width); + } + for (const auto &r : join_op.left_projection_map) { + swap_projection_map.emplace_back(r); + } + join_op.predicate = op.predicate->Copy(); + ExpressionIterator::EnumerateExpression(join_op.predicate, [&](Expression &child) { + if (child.GetExpressionClass() == ExpressionClass::BOUND_REF) { + auto &col_idx = child.Cast().index; + const auto new_idx = swap_projection_map[col_idx]; + col_idx = new_idx; + } + }); + } auto binder = Binder::CreateBinder(context); FunctionBinder function_binder(*binder); @@ -88,13 +103,13 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera case ExpressionType::COMPARE_GREATERTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_max"; + arg_min_max = "arg_max_null"; break; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_LESSTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_min"; + arg_min_max = "arg_min_null"; break; case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOTEQUAL: @@ -208,7 +223,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera auto window_types = probe.GetTypes(); window_types.emplace_back(pk_type); - idx_t probe_cardinality = op.children[0]->EstimateCardinality(context); + const auto probe_cardinality = op.EstimateCardinality(context); auto &window = Make(window_types, std::move(window_select), probe_cardinality); window.children.emplace_back(probe); @@ -275,10 +290,12 @@ PhysicalOperator &PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) } D_ASSERT(asof_idx < op.conditions.size()); - bool force_asof_join = DBConfig::GetSetting(context); - if (!force_asof_join) { - idx_t asof_join_threshold = DBConfig::GetSetting(context); - if (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold) { + // If there is a non-comparison predicate, we have to use NLJ. 
+ const bool has_predicate = op.predicate.get(); + const bool force_asof_join = DBConfig::GetSetting(context); + if (!force_asof_join || has_predicate) { + const idx_t asof_join_threshold = DBConfig::GetSetting(context); + if (has_predicate || (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold)) { auto result = PlanAsOfLoopJoin(op, left, right); if (result) { return *result; diff --git a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp index f1b8aec47..39b6d96e8 100644 --- a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp @@ -65,7 +65,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { if (ClientConfig::GetConfig(context).enable_optimizer) { bool changes_made = false; - auto new_expr = OrderedAggregateOptimizer::Apply(context, *first_aggregate, groups, changes_made); + auto new_expr = + OrderedAggregateOptimizer::Apply(context, *first_aggregate, groups, nullptr, changes_made); if (new_expr) { D_ASSERT(new_expr->return_type == first_aggregate->return_type); D_ASSERT(new_expr->GetExpressionType() == ExpressionType::BOUND_AGGREGATE); @@ -81,7 +82,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { } } - child = ExtractAggregateExpressions(child, aggregates, groups); + child = ExtractAggregateExpressions(child, aggregates, groups, nullptr); // we add a physical hash aggregation in the plan to select the distinct groups auto &group_by = Make(context, aggregate_types, std::move(aggregates), std::move(groups), diff --git a/src/duckdb/src/execution/physical_plan/plan_filter.cpp b/src/duckdb/src/execution/physical_plan/plan_filter.cpp index 292fe1bc8..796e4aeb3 100644 --- a/src/duckdb/src/execution/physical_plan/plan_filter.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_filter.cpp @@ -14,7 +14,6 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalFilter &op) { D_ASSERT(op.children.size() == 1); reference plan = CreatePlan(*op.children[0]); if (!op.expressions.empty()) { - D_ASSERT(!plan.get().GetTypes().empty()); // create a filter if there is anything to filter auto &filter = Make(plan.get().GetTypes(), std::move(op.expressions), op.estimated_cardinality); filter.children.push_back(plan); diff --git a/src/duckdb/src/execution/physical_plan/plan_sample.cpp b/src/duckdb/src/execution/physical_plan/plan_sample.cpp index c88d8c741..65aa2ea9b 100644 --- a/src/duckdb/src/execution/physical_plan/plan_sample.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_sample.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/operator/logical_sample.hpp" #include "duckdb/common/enum_util.hpp" #include "duckdb/common/random_engine.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp index c9cab9e8c..ace9b5c1a 100644 --- a/src/duckdb/src/execution/physical_plan/plan_window.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_window.cpp @@ -2,13 +2,11 @@ #include "duckdb/execution/operator/aggregate/physical_window.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_config.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include 
"duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/planner/operator/logical_window.hpp" -#include - namespace duckdb { PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { @@ -44,12 +42,12 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { // Process the window functions by sharing the partition/order definitions unordered_map projection_map; vector> window_expressions; - idx_t blocking_count = 0; + idx_t streaming_count = 0; auto output_pos = input_width; while (!blocking_windows.empty() || !streaming_windows.empty()) { - const bool process_streaming = blocking_windows.empty(); - auto &remaining = process_streaming ? streaming_windows : blocking_windows; - blocking_count += process_streaming ? 0 : 1; + const bool process_blocking = streaming_windows.empty(); + auto &remaining = process_blocking ? blocking_windows : streaming_windows; + streaming_count += process_blocking ? 0 : 1; // Find all functions that share the partitioning of the first remaining expression auto over_idx = remaining[0]; @@ -122,7 +120,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { } // Chain the new window operator on top of the plan - if (i < blocking_count) { + if (i >= streaming_count) { auto &window = Make(types, std::move(select_list), op.estimated_cardinality); window.children.push_back(plan); plan = window; diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp index 58469efe4..74b420a34 100644 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -30,18 +30,18 @@ PhysicalOperator &PhysicalPlanGenerator::ResolveAndPlan(unique_ptrResolveOperatorTypes(); profiler.EndPhase(); // Resolve the column references. - profiler.StartPhase(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING); + profiler.StartPhase(MetricType::PHYSICAL_PLANNER_COLUMN_BINDING); ColumnBindingResolver resolver; resolver.VisitOperator(*op); profiler.EndPhase(); // Create the main physical plan. 
- profiler.StartPhase(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN); + profiler.StartPhase(MetricType::PHYSICAL_PLANNER_CREATE_PLAN); physical_plan = PlanInternal(*op); profiler.EndPhase(); diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp index 7265a23e4..9cd32cc5e 100644 --- a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp +++ b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp @@ -208,7 +208,6 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R any_combined(false), radix_ht(radix_ht_p), config(*this), stored_allocators_size(0), finalize_done(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0), max_partition_size(0) { - // Compute minimum reservation auto block_alloc_size = BufferManager::GetBufferManager(context).GetBlockAllocSize(); auto tuples_per_block = block_alloc_size / radix_ht.GetLayout().GetRowWidth(); @@ -467,8 +466,8 @@ void MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra // We're approaching the memory limit, unpin the data if (!lstate.abandoned_data) { lstate.abandoned_data = make_uniq( - BufferManager::GetBufferManager(context), gstate.radix_ht.GetLayoutPtr(), config.GetRadixBits(), - gstate.radix_ht.GetLayout().ColumnCount() - 1); + BufferManager::GetBufferManager(context), gstate.radix_ht.GetLayoutPtr(), MemoryTag::HASH_TABLE, + config.GetRadixBits(), gstate.radix_ht.GetLayout().ColumnCount() - 1); } ht.SetRadixBits(gstate.config.GetRadixBits()); ht.AcquirePartitionedData()->Repartition(context, *lstate.abandoned_data); @@ -827,8 +826,8 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob partition.progress = 1; // Move the combined data back to the partition - partition.data = - make_uniq(BufferManager::GetBufferManager(gstate.context), sink.radix_ht.GetLayoutPtr()); + partition.data = make_uniq(BufferManager::GetBufferManager(gstate.context), + sink.radix_ht.GetLayoutPtr(), MemoryTag::HASH_TABLE); partition.data->Combine(*ht->AcquirePartitionedData()->GetPartitions()[0]); // Update thread-global state @@ -948,14 +947,16 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D for (idx_t i = 0; i < op.aggregates.size(); i++) { D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = op.aggregates[i]->Cast(); - auto aggr_state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); - aggr.function.initialize(aggr.function, aggr_state.get()); + auto aggr_state = + make_unsafe_uniq_array_uninitialized(aggr.function.GetStateSizeCallback()(aggr.function)); + aggr.function.GetStateInitCallback()(aggr.function, aggr_state.get()); AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get()))); - aggr.function.finalize(state_vector, aggr_input_data, chunk.data[null_groups.size() + i], 1, 0); - if (aggr.function.destructor) { - aggr.function.destructor(state_vector, aggr_input_data, 1); + aggr.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, + chunk.data[null_groups.size() + i], 1, 0); + if (aggr.function.HasStateDestructorCallback()) { + aggr.function.GetStateDestructorCallback()(state_vector, aggr_input_data, 1); } } // Place the grouping values (all the groups of the grouping_set condensed into a single value) diff --git 
a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp b/src/duckdb/src/execution/sample/base_reservoir_sample.cpp index 35de6d54f..3be480eaf 100644 --- a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp +++ b/src/duckdb/src/execution/sample/base_reservoir_sample.cpp @@ -60,7 +60,6 @@ void BaseReservoirSampling::SetNextEntry() { } void BaseReservoirSampling::ReplaceElementWithIndex(idx_t entry_index, double with_weight, bool pop) { - if (pop) { reservoir_weights.pop(); } diff --git a/src/duckdb/src/execution/sample/reservoir_sample.cpp b/src/duckdb/src/execution/sample/reservoir_sample.cpp index cb52f3f2b..a603bc19f 100644 --- a/src/duckdb/src/execution/sample/reservoir_sample.cpp +++ b/src/duckdb/src/execution/sample/reservoir_sample.cpp @@ -190,7 +190,6 @@ void ReservoirSample::Vacuum() { } unique_ptr ReservoirSample::Copy() const { - auto ret = make_uniq(sample_count); ret->stats_sample = stats_sample; @@ -271,7 +270,7 @@ void ReservoirSample::SimpleMerge(ReservoirSample &other) { auto weight_tuples_this = static_cast(GetTuplesSeen()) / static_cast(total_seen); auto weight_tuples_other = static_cast(other.GetTuplesSeen()) / static_cast(total_seen); - // If weights don't add up to 1, most likely a simple merge occured and no new samples were added. + // If weights don't add up to 1, most likely a simple merge occurred and no new samples were added. // if that is the case, add the missing weight to the lower weighted sample to adjust. // this is to avoid cases where if you have a 20k row table and add another 20k rows row by row // then eventually the missing weights will add up, and get you a more even distribution @@ -564,7 +563,6 @@ T ReservoirSample::GetReservoirChunkCapacity() const { } idx_t ReservoirSample::FillReservoir(DataChunk &chunk) { - idx_t ingested_count = 0; if (!reservoir_chunk) { if (chunk.size() > FIXED_SAMPLE_SIZE) { @@ -609,7 +607,6 @@ SelectionVectorHelper ReservoirSample::GetReplacementIndexes(idx_t sample_chunk_ } SelectionVectorHelper ReservoirSample::GetReplacementIndexesFast(idx_t sample_chunk_offset, idx_t chunk_length) { - // how much weight to the other tuples have compared to the ones in this chunk? 
auto weight_tuples_other = static_cast(chunk_length) / static_cast(GetTuplesSeen() + chunk_length); auto num_to_pop = static_cast(round(weight_tuples_other * static_cast(sample_count))); diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp index 41af395aa..29a3ff065 100644 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -232,22 +232,22 @@ AggregateFunction CountFunctionBase::GetFunction() { AggregateFunction::StateFinalize, FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); fun.name = "count"; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); return fun; } AggregateFunction CountStarFun::GetFunction() { auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); fun.name = "count_star"; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = CountStarFunction::Window; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + fun.SetWindowCallback(CountStarFunction::Window); return fun; } AggregateFunctionSet CountFun::GetFunctions() { AggregateFunction count_function = CountFunctionBase::GetFunction(); - count_function.statistics = CountPropagateStats; + count_function.SetStatisticsCallback(CountPropagateStats); AggregateFunctionSet count("count"); count.AddFunction(count_function); // the count function can also be called without arguments diff --git a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp index 442eec461..0eed72d84 100644 --- a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp +++ b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp @@ -213,7 +213,7 @@ struct FirstVectorFunction : FirstFunctionStringBase { static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; + function.SetReturnType(arguments[0]->return_type); return nullptr; } }; @@ -233,7 +233,7 @@ void FirstFunctionSimpleUpdate(Vector inputs[], AggregateInputData &aggregate_in template AggregateFunction GetFirstAggregateTemplated(LogicalType type) { auto result = AggregateFunction::UnaryAggregate, T, T, FirstFunction>(type, type); - result.simple_update = FirstFunctionSimpleUpdate; + result.SetStateSimpleUpdateCallback(FirstFunctionSimpleUpdate); return result; } @@ -260,7 +260,7 @@ AggregateFunction GetFirstFunction(const LogicalType &type) { type.Verify(); AggregateFunction function = GetDecimalFirstFunction(type); function.arguments[0] = type; - function.return_type = type; + function.SetReturnType(type); return function; } switch (type.InternalType()) { @@ -317,8 +317,8 @@ unique_ptr BindDecimalFirst(ClientContext &context, AggregateFunct auto name = std::move(function.name); function = GetFirstFunction(decimal_type); function.name = std::move(name); - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - function.return_type = decimal_type; + function.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); + function.SetReturnType(decimal_type); return nullptr; } @@ -337,9 
+337,9 @@ unique_ptr BindFirst(ClientContext &context, AggregateFunction &fu auto name = std::move(function.name); function = GetFirstOperator(input_type); function.name = std::move(name); - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - if (function.bind) { - return function.bind(context, function, arguments); + function.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); + if (function.HasBindCallback()) { + return function.GetBindCallback()(context, function, arguments); } else { return nullptr; } diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index ce5ef12af..ce8f80aea 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -296,7 +296,7 @@ struct VectorMinMaxBase { static unique_ptr Bind(ClientContext &context, AggregateFunction &function, vector> &arguments) { function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; + function.SetReturnType(arguments[0]->return_type); return nullptr; } }; @@ -367,8 +367,8 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f // Bind function like arg_min/arg_max. function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; + function.SetReturnType(arguments[0]->return_type); + return make_uniq(); } } @@ -379,10 +379,10 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f auto name = std::move(function.name); function = GetMinMaxOperator(input_type); function.name = std::move(name); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - if (function.bind) { - return function.bind(context, function, arguments); + function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); + function.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); + if (function.HasBindCallback()) { + return function.GetBindCallback()(context, function, arguments); } else { return nullptr; } @@ -431,7 +431,6 @@ class MinMaxNState { template void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, idx_t count) { - auto &val_vector = inputs[0]; auto &n_vector = inputs[1]; @@ -441,7 +440,7 @@ void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_ auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); + STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, true); n_vector.ToUnifiedFormat(count, n_format); state_vector.ToUnifiedFormat(count, state_format); @@ -484,13 +483,13 @@ void SpecializeMinMaxNFunction(AggregateFunction &function) { using STATE = MinMaxNState; using OP = MinMaxNOperation; - function.state_size = AggregateFunction::StateSize; - function.initialize = AggregateFunction::StateInitialize; - function.combine = AggregateFunction::StateCombine; - function.destructor = AggregateFunction::StateDestroy; + function.SetStateSizeCallback(AggregateFunction::StateSize); + function.SetStateInitCallback(AggregateFunction::StateInitialize); + function.SetStateCombineCallback(AggregateFunction::StateCombine); + function.SetStateDestructorCallback(AggregateFunction::StateDestroy); - 
function.finalize = MinMaxNOperation::Finalize; - function.update = MinMaxNUpdate; + function.SetStateFinalizeCallback(MinMaxNOperation::Finalize); + function.SetStateUpdateCallback(MinMaxNUpdate); } template @@ -520,7 +519,6 @@ void SpecializeMinMaxNFunction(PhysicalType arg_type, AggregateFunction &functio template unique_ptr MinMaxNBind(ClientContext &context, AggregateFunction &function, vector> &arguments) { - for (auto &arg : arguments) { if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); @@ -532,7 +530,7 @@ unique_ptr MinMaxNBind(ClientContext &context, AggregateFunction & // Specialize the function based on the input types SpecializeMinMaxNFunction(val_type, function); - function.return_type = LogicalType::LIST(arguments[0]->return_type); + function.SetReturnType(LogicalType::LIST(arguments[0]->return_type)); return nullptr; } diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index bde4c1479..ea78de81b 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -10,6 +10,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/parallel/thread_context.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -24,7 +25,6 @@ struct SortedAggregateBindData : public FunctionData { BindInfoPtr &bind_info, OrderBys &order_bys) : context(context), function(aggregate), bind_info(std::move(bind_info)), threshold(DBConfig::GetSetting(context)) { - // Describe the arguments. for (const auto &child : children) { buffered_cols.emplace_back(buffered_cols.size()); @@ -433,7 +433,6 @@ struct SortedAggregateFunction { static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, idx_t count, DataChunk &buffered) { - // Only reference the buffered columns buffered.InitializeEmpty(order_bind.buffered_types); const auto &buffered_cols = order_bind.buffered_cols; @@ -531,7 +530,7 @@ struct SortedAggregateFunction { // Reusable inner state auto &aggr = order_bind.function; - vector agg_state(aggr.state_size(aggr)); + vector agg_state(aggr.GetStateSizeCallback()(aggr)); Vector agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data()))); // State variables @@ -539,11 +538,11 @@ struct SortedAggregateFunction { AggregateInputData aggr_bind_info(bind_info, aggr_input_data.allocator); // Inner aggregate APIs - auto initialize = aggr.initialize; - auto destructor = aggr.destructor; - auto simple_update = aggr.simple_update; - auto update = aggr.update; - auto finalize = aggr.finalize; + auto initialize = aggr.GetStateInitCallback(); + auto destructor = aggr.GetStateDestructorCallback(); + auto simple_update = aggr.GetStateSimpleUpdateCallback(); + auto update = aggr.GetStateUpdateCallback(); + auto finalize = aggr.GetStateFinalizeCallback(); auto sdata = FlatVector::GetData(states); @@ -677,14 +676,15 @@ struct SortedAggregateFunction { } // namespace void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, - const vector> &groups) { + const vector> &groups, + optional_ptr> grouping_sets) { if (!expr.order_bys || expr.order_bys->orders.empty() || expr.children.empty()) { // not a sorted aggregate: return return; } // Remove unnecessary ORDER BY clauses and return if nothing 
remains if (context.config.enable_optimizer) { - if (expr.order_bys->Simplify(groups)) { + if (expr.order_bys->Simplify(groups, grouping_sets)) { expr.order_bys.reset(); return; } @@ -709,13 +709,14 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE // Replace the aggregate with the wrapper AggregateFunction ordered_aggregate( - bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize, + bound_function.name, arguments, bound_function.GetReturnType(), + AggregateFunction::StateSize, AggregateFunction::StateInitialize, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, - AggregateFunction::StateDestroy, nullptr, + SortedAggregateFunction::Finalize, bound_function.GetNullHandling(), SortedAggregateFunction::SimpleUpdate, + nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::Window); expr.function = std::move(ordered_aggregate); @@ -726,7 +727,7 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr) { // Make implicit orderings explicit auto &aggregate = *expr.aggregate; - if (aggregate.order_dependent == AggregateOrderDependent::ORDER_DEPENDENT && expr.arg_orders.empty()) { + if (aggregate.GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT && expr.arg_orders.empty()) { for (auto &order : expr.orders) { const auto type = order.type; const auto null_order = order.null_order; @@ -741,7 +742,7 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpr } // Remove unnecessary ORDER BY clauses and return if nothing remains if (context.config.enable_optimizer) { - if (BoundOrderModifier::Simplify(expr.arg_orders, expr.partitions)) { + if (BoundOrderModifier::Simplify(expr.arg_orders, expr.partitions, nullptr)) { expr.arg_orders.clear(); return; } @@ -765,12 +766,12 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpr // Replace the aggregate with the wrapper AggregateFunction ordered_aggregate( - aggregate.name, arguments, aggregate.return_type, AggregateFunction::StateSize, + aggregate.name, arguments, aggregate.GetReturnType(), AggregateFunction::StateSize, AggregateFunction::StateInitialize, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, aggregate.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, + SortedAggregateFunction::Finalize, aggregate.GetNullHandling(), SortedAggregateFunction::SimpleUpdate, nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::Window); diff --git a/src/duckdb/src/function/built_in_functions.cpp b/src/duckdb/src/function/built_in_functions.cpp index f13bfe677..58476e22d 100644 --- a/src/duckdb/src/function/built_in_functions.cpp +++ b/src/duckdb/src/function/built_in_functions.cpp @@ -102,7 +102,7 @@ unique_ptr BindExtensionFunction(ClientContext &context, ScalarFun // but the extension is not loaded // try to autoload the extension // first figure out which extension we need to auto-load - auto &function_info = bound_function.function_info->Cast(); + auto &function_info = bound_function.GetExtraFunctionInfo().Cast(); auto &extension_name = function_info.extension; auto &db = *context.db; @@ -120,10 +120,10 @@ unique_ptr BindExtensionFunction(ClientContext &context, 
ScalarFun // override the function with the extension function bound_function = function_entry.functions.GetFunctionByArguments(context, bound_function.arguments); // call the original bind (if any) - if (!bound_function.bind) { + if (!bound_function.HasBindCallback()) { return nullptr; } - return bound_function.bind(context, bound_function, arguments); + return bound_function.GetBindCallback()(context, bound_function, arguments); } void BuiltinFunctions::AddExtensionFunction(ScalarFunctionSet set) { @@ -154,7 +154,7 @@ void BuiltinFunctions::RegisterExtensionOverloads() { ScalarFunction function(entry.name, std::move(arguments), std::move(return_type), nullptr, BindExtensionFunction); - function.function_info = make_shared_ptr(entry.extension); + function.SetExtraFunctionInfo(entry.extension); if (current_set.name != entry.name) { if (!current_set.name.empty()) { // create set of functions diff --git a/src/duckdb/src/function/cast/array_casts.cpp b/src/duckdb/src/function/cast/array_casts.cpp index 2357a2c2c..ad243ef2c 100644 --- a/src/duckdb/src/function/cast/array_casts.cpp +++ b/src/duckdb/src/function/cast/array_casts.cpp @@ -39,7 +39,6 @@ unique_ptr ArrayBoundCastData::InitArrayLocalState(CastLocal // ARRAY -> ARRAY //------------------------------------------------------------------------------ static bool ArrayToArrayCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_array_size = ArrayType::GetSize(source.GetType()); auto target_array_size = ArrayType::GetSize(result.GetType()); if (source_array_size != target_array_size) { diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp index 4e6ed8b99..606fa9010 100644 --- a/src/duckdb/src/function/cast/cast_function_set.cpp +++ b/src/duckdb/src/function/cast/cast_function_set.cpp @@ -184,7 +184,9 @@ int64_t CastFunctionSet::ImplicitCastCost(optional_ptr context, c old_implicit_casting = DBConfig::GetSetting(*config); } if (old_implicit_casting) { - score = 149; + // very high cost to avoid choosing this cast if any other option is available + // (it should be more costly than casting to TEMPLATE if that is available) + score = 10000000000; } } return score; diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 0c0c1c058..558329f70 100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -162,6 +162,8 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return EnumCastSwitch(input, source, target); case LogicalTypeId::ARRAY: return ArrayCastSwitch(input, source, target); + case LogicalTypeId::GEOMETRY: + return GeoCastSwitch(input, source, target); case LogicalTypeId::BIGNUM: return BignumCastSwitch(input, source, target); case LogicalTypeId::AGGREGATE_STATE: diff --git a/src/duckdb/src/function/cast/geo_casts.cpp b/src/duckdb/src/function/cast/geo_casts.cpp new file mode 100644 index 000000000..59595359f --- /dev/null +++ b/src/duckdb/src/function/cast/geo_casts.cpp @@ -0,0 +1,23 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +static bool GeometryToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + UnaryExecutor::Execute( + source, result, count, [&](const string_t &input) -> string_t { return Geometry::ToString(result, input); 
}); + return true; +} + +BoundCastInfo DefaultCasts::GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + return GeometryToVarcharCast; + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 511d09a86..930231808 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -490,6 +490,8 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::UUID: return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::SQLNULL: return &DefaultCasts::TryVectorNullCast; case LogicalTypeId::VARCHAR: diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp index 97a9354d1..12c60bd75 100644 --- a/src/duckdb/src/function/cast/struct_cast.cpp +++ b/src/duckdb/src/function/cast/struct_cast.cpp @@ -12,8 +12,8 @@ unique_ptr StructBoundCastData::BindStructToStructCast(BindCastIn auto &source_children = StructType::GetChildTypes(source); auto &target_children = StructType::GetChildTypes(target); - auto target_is_unnamed = StructType::IsUnnamed(target); - auto source_is_unnamed = StructType::IsUnnamed(source); + auto target_is_unnamed = target_children.empty() || StructType::IsUnnamed(target); + auto source_is_unnamed = source_children.empty() || StructType::IsUnnamed(source); auto is_unnamed = target_is_unnamed || source_is_unnamed; if (is_unnamed && source_children.size() != target_children.size()) { @@ -268,7 +268,6 @@ StructToMapBoundCastData::InitStructToMapCastLocalState(CastLocalStateParameters } static bool StructToMapCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { // Optimization: if the source vector is constant, we only have a single physical element, so we can set the // result vectortype to ConstantVector as well and set the (logical) count to 1 diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp index 65f018d2b..5a7b5d466 100644 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -56,7 +56,6 @@ static unique_ptr BindToUnionCast(BindCastInput &input, const Log // check if the cast is ambiguous (2 or more casts have the same cost) if (candidates.size() > 1 && candidates[1].cost == selected_cost) { - // collect all the ambiguous types auto message = StringUtil::Format( "Type %s can't be cast as %s. 
The cast is ambiguous, multiple possible members in target: ", source, @@ -107,7 +106,6 @@ static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParamet BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(target.id() == LogicalTypeId::UNION); if (StructToUnionCast::AllowImplicitCastFromStruct(source, target)) { return StructToUnionCast::Bind(input, source, target); @@ -130,7 +128,6 @@ BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const Logi // INVALID: UNION(A, B, D) -> UNION(A, B, C) struct UnionUnionBoundCastData : public BoundCastData { - // mapping from source member index to target member index // these are always the same size as the source member count // (since all source members must be present in the target) @@ -284,7 +281,6 @@ static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastPa FlatVector::GetData(result_tag_vector)[row_idx] = UnsafeNumericCast(target_tag); } else { - // Issue: The members of the result is not always flatvectors // In the case of TryNullCast, the result member is constant. FlatVector::SetNull(result, row_idx, true); diff --git a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index ca377b326..f29db2b85 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -1,3 +1,4 @@ +#include "yyjson_utils.hpp" #include "duckdb/function/cast/default_casts.hpp" #include "duckdb/common/types/variant.hpp" #include "duckdb/function/scalar/variant_utils.hpp" @@ -49,22 +50,6 @@ struct DecimalConversionPayloadFromVariant { idx_t scale; }; -struct ConvertedJSONHolder { -public: - ~ConvertedJSONHolder() { - if (doc) { - yyjson_mut_doc_free(doc); - } - if (stringified_json) { - free(stringified_json); - } - } - -public: - yyjson_mut_doc *doc = nullptr; - char *stringified_json = nullptr; -}; - } // namespace //===--------------------------------------------------------------------===// @@ -364,6 +349,14 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V SelectionVector child_values_sel; child_values_sel.Initialize(count); + SelectionVector row_sel(0, count); + if (row.IsValid()) { + auto row_index = row.GetIndex(); + for (idx_t i = 0; i < count; i++) { + row_sel[i] = static_cast(row_index); + } + } + for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { auto &child_name = child_types[child_idx].first; @@ -372,14 +365,21 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V VariantPathComponent component; component.key = child_name; component.lookup_mode = VariantChildLookupMode::BY_KEY; - auto collection_result = - VariantUtils::FindChildValues(conversion_data.variant, component, row, child_values_sel, child_data, count); - if (!collection_result.Success()) { - D_ASSERT(collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = collection_result.nested_data_index; - auto row_index = row.IsValid() ? 
row.GetIndex() : nested_index; + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(conversion_data.variant, component, row_sel, child_values_sel, lookup_validity, + child_data, count); + if (!lookup_validity.AllValid()) { + optional_idx nested_index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + nested_index = i; + break; + } + } + D_ASSERT(nested_index.IsValid()); + auto row_index = row.IsValid() ? row.GetIndex() : nested_index.GetIndex(); auto object_keys = - VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index]); + VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index.GetIndex()]); conversion_data.error = StringUtil::Format("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), component.key); return false; @@ -550,6 +550,11 @@ static bool CastVariant(FromVariantConversionData &conversion_data, Vector &resu return CastVariantToPrimitive>( conversion_data, result, sel, offset, count, row, string_payload); } + case LogicalTypeId::GEOMETRY: { + StringConversionPayload string_payload(result); + return CastVariantToPrimitive>( + conversion_data, result, sel, offset, count, row, string_payload); + } case LogicalTypeId::VARCHAR: { if (target_type.IsJSONType()) { return CastVariantToJSON(conversion_data, result, sel, offset, count, row); @@ -686,6 +691,8 @@ BoundCastInfo DefaultCasts::VariantCastSwitch(BindCastInput &input, const Logica case LogicalTypeId::UUID: case LogicalTypeId::ARRAY: return BoundCastInfo(CastFromVARIANT); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(CastFromVARIANT); case LogicalTypeId::VARCHAR: { return BoundCastInfo(CastFromVARIANT); } diff --git a/src/duckdb/src/function/cast/variant/to_json.cpp b/src/duckdb/src/function/cast/variant/to_json.cpp index 9d35c142c..482fa90c2 100644 --- a/src/duckdb/src/function/cast/variant/to_json.cpp +++ b/src/duckdb/src/function/cast/variant/to_json.cpp @@ -10,6 +10,7 @@ #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/variant_visitor.hpp" using namespace duckdb_yyjson; // NOLINT @@ -17,256 +18,211 @@ namespace duckdb { //! ------------ Variant -> JSON ------------ -yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, - idx_t row, uint32_t values_idx) { - auto index = source.unified.sel->get_index(row); - if (!source.unified.validity.RowIsValid(index)) { - return yyjson_mut_null(doc); - } +namespace { + +struct JSONConverter { + using result_type = yyjson_mut_val *; - //! values - auto &values = UnifiedVariantVector::GetValues(source); - auto values_data = values.GetData(values); - - //! type_ids - auto &type_ids = UnifiedVariantVector::GetValuesTypeId(source); - auto type_ids_data = type_ids.GetData(type_ids); - - //! byte_offsets - auto &byte_offsets = UnifiedVariantVector::GetValuesByteOffset(source); - auto byte_offsets_data = byte_offsets.GetData(byte_offsets); - - //! children - auto &children = UnifiedVariantVector::GetChildren(source); - auto children_data = children.GetData(children); - - //! values_index - auto &values_index = UnifiedVariantVector::GetChildrenValuesIndex(source); - auto values_index_data = values_index.GetData(values_index); - - //! keys_index - auto &keys_index = UnifiedVariantVector::GetChildrenKeysIndex(source); - auto keys_index_data = keys_index.GetData(keys_index); - - //! 
keys - auto &keys = UnifiedVariantVector::GetKeys(source); - auto keys_data = keys.GetData(keys); - auto &keys_entry = UnifiedVariantVector::GetKeysEntry(source); - auto keys_entry_data = keys_entry.GetData(keys_entry); - - //! list entries - auto keys_list_entry = keys_data[keys.sel->get_index(row)]; - auto children_list_entry = children_data[children.sel->get_index(row)]; - auto values_list_entry = values_data[values.sel->get_index(row)]; - - //! The 'values' data of the value we're currently converting - values_idx += values_list_entry.offset; - auto type_id = static_cast(type_ids_data[type_ids.sel->get_index(values_idx)]); - auto byte_offset = byte_offsets_data[byte_offsets.sel->get_index(values_idx)]; - - //! The blob data of the Variant, accessed by byte offset retrieved above ^ - auto &value = UnifiedVariantVector::GetData(source); - auto value_data = value.GetData(value); - auto &blob = value_data[value.sel->get_index(row)]; - auto blob_data = const_data_ptr_cast(blob.GetData()); - - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: + static yyjson_mut_val *VisitNull(yyjson_mut_doc *doc) { return yyjson_mut_null(doc); - case VariantLogicalType::BOOL_TRUE: - return yyjson_mut_true(doc); - case VariantLogicalType::BOOL_FALSE: - return yyjson_mut_false(doc); - case VariantLogicalType::INT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); - } - case VariantLogicalType::INT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); } - case VariantLogicalType::INT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitBoolean(bool val, yyjson_mut_doc *doc) { + return val ? yyjson_mut_true(doc) : yyjson_mut_false(doc); } - case VariantLogicalType::INT64: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + template + static yyjson_mut_val *VisitInteger(T val, yyjson_mut_doc *doc) { + throw InternalException("JSONConverter::VisitInteger not implemented!"); } - case VariantLogicalType::INT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTime(dtime_t val, yyjson_mut_doc *doc) { + auto val_str = Time::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT8: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeNanos(dtime_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIME_NS(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT16: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimeTZ(dtime_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMETZ(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT32: { - auto val = Load(ptr); - return yyjson_mut_sint(doc, val); + + static yyjson_mut_val *VisitTimestampSec(timestamp_sec_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPSEC(val).ToString(); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT64: { - auto val = Load(ptr); - return yyjson_mut_uint(doc, val); + + static yyjson_mut_val *VisitTimestampMs(timestamp_ms_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPMS(val).ToString(); + return yyjson_mut_strncpy(doc, 
val_str.c_str(), val_str.size()); } - case VariantLogicalType::UINT128: { - auto val = Load(ptr); - auto val_str = val.ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitTimestamp(timestamp_t val, yyjson_mut_doc *doc) { + auto val_str = Timestamp::ToString(val); + return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::UUID: { - auto val = Value::UUID(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampNanos(timestamp_ns_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPNS(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::INTERVAL: { - auto val = Value::INTERVAL(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitTimestampTZ(timestamp_tz_t val, yyjson_mut_doc *doc) { + auto val_str = Value::TIMESTAMPTZ(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::FLOAT: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitFloat(float val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DOUBLE: { - auto val = Load(ptr); + + static yyjson_mut_val *VisitDouble(double val, yyjson_mut_doc *doc) { return yyjson_mut_real(doc, val); } - case VariantLogicalType::DATE: { - auto val = Load(ptr); - auto val_str = Date::ToString(date_t(val)); + + static yyjson_mut_val *VisitUUID(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = Value::UUID(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BLOB(const_data_ptr_cast(string_data), string_length).ToString(); + + static yyjson_mut_val *VisitDate(date_t val, yyjson_mut_doc *doc) { + auto val_str = Date::ToString(val); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return yyjson_mut_strncpy(doc, string_data, static_cast(string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - string val_str; - if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - val_str = Decimal::ToString(Load(ptr), width, scale); - } else { - val_str = Decimal::ToString(Load(ptr), width, scale); - } - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::TIME_MICROS: { - auto val = Load(ptr); - auto val_str = Time::ToString(val); + static yyjson_mut_val *VisitInterval(interval_t val, yyjson_mut_doc *doc) { + auto val_str = Value::INTERVAL(val).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIME_MICROS_TZ: { - auto val = Value::TIMETZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitString(const string_t &str, yyjson_mut_doc *doc) { + return yyjson_mut_strncpy(doc, str.GetData(), str.GetSize()); } - case VariantLogicalType::TIMESTAMP_MICROS: { - auto val = Load(ptr); - auto val_str = 
Timestamp::ToString(val); + + static yyjson_mut_val *VisitBlob(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_SEC: { - auto val = Value::TIMESTAMPSEC(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + static yyjson_mut_val *VisitBignum(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_NANOS: { - auto val = Value::TIMESTAMPNS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitGeometry(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MILIS: { - auto val = Value::TIMESTAMPMS(Load(ptr)); - auto val_str = val.ToString(); + + static yyjson_mut_val *VisitBitstring(const string_t &str, yyjson_mut_doc *doc) { + auto val_str = Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::TIMESTAMP_MICROS_TZ: { - auto val = Value::TIMESTAMPTZ(Load(ptr)); - auto val_str = val.ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); + + template + static yyjson_mut_val *VisitDecimal(T val, uint32_t width, uint32_t scale, yyjson_mut_doc *doc) { + string val_str; + if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + val_str = Decimal::ToString(val, static_cast(width), static_cast(scale)); + } else { + throw InternalException("Unhandled decimal type"); + } + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); } - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitArray(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, yyjson_mut_doc *doc) { auto arr = yyjson_mut_arr(doc); - if (!count) { - return arr; - } - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_index = values_index_data[index]; -#ifdef DEBUG - auto key_id_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(!keys_index.validity.RowIsValid(key_id_index)); -#endif - auto val = ConvertVariantToJSON(doc, source, row, child_index); - if (!val) { - return nullptr; - } - yyjson_mut_arr_add_val(arr, val); + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data, doc); + for (auto &entry : array_items) { + yyjson_mut_arr_add_val(arr, entry); } return arr; } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); + + static yyjson_mut_val *VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData 
&nested_data, yyjson_mut_doc *doc) { auto obj = yyjson_mut_obj(doc); - if (!count) { - return obj; - } - auto child_index_start = VarintDecode(ptr); - - for (idx_t i = 0; i < count; i++) { - auto children_index = values_index.sel->get_index(children_list_entry.offset + child_index_start + i); - auto child_value_idx = values_index_data[children_index]; - auto val = ConvertVariantToJSON(doc, source, row, child_value_idx); - if (!val) { - return nullptr; - } - auto keys_index_index = keys_index.sel->get_index(children_list_entry.offset + child_index_start + i); - D_ASSERT(keys_index.validity.RowIsValid(keys_index_index)); - auto child_key_id = keys_index_data[keys_index_index]; - auto &key = keys_entry_data[keys_entry.sel->get_index(keys_list_entry.offset + child_key_id)]; - yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, key.GetData(), key.GetSize()), val); + auto object_items = VariantVisitor::VisitObjectItems(variant, row, nested_data, doc); + for (auto &entry : object_items) { + yyjson_mut_obj_put(obj, yyjson_mut_strncpy(doc, entry.first.c_str(), entry.first.size()), entry.second); } return obj; } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIT(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); - } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - auto val_str = Value::BIGNUM(const_data_ptr_cast(string_data), string_length).ToString(); - return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); - } - default: - throw InternalException("VariantLogicalType(%d) not handled", static_cast(type_id)); + + static yyjson_mut_val *VisitDefault(VariantLogicalType type_id, const_data_ptr_t, yyjson_mut_doc *) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } +}; + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} - return nullptr; +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(int64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(hugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint8_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint16_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint32_t val, yyjson_mut_doc *doc) { + return yyjson_mut_sint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uint64_t val, yyjson_mut_doc *doc) { + return yyjson_mut_uint(doc, val); +} + +template <> +yyjson_mut_val *JSONConverter::VisitInteger(uhugeint_t val, yyjson_mut_doc *doc) { + auto val_str = val.ToString(); + return yyjson_mut_rawncpy(doc, val_str.c_str(), val_str.size()); +} + +} // namespace + +yyjson_mut_val 
*VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const RecursiveUnifiedVectorFormat &source, + idx_t row, uint32_t values_idx) { + UnifiedVariantVectorData variant(source); + return VariantVisitor::Visit(variant, row, values_idx, doc); } } // namespace duckdb diff --git a/src/duckdb/src/function/cast/variant/to_variant.cpp b/src/duckdb/src/function/cast/variant/to_variant.cpp index ad1962d37..7402863e6 100644 --- a/src/duckdb/src/function/cast/variant/to_variant.cpp +++ b/src/duckdb/src/function/cast/variant/to_variant.cpp @@ -8,10 +8,9 @@ #include "duckdb/function/cast/variant/to_variant.hpp" namespace duckdb { - namespace variant { -static void InitializeOffsets(DataChunk &offsets, idx_t count) { +void InitializeOffsets(DataChunk &offsets, idx_t count) { auto keys = OffsetData::GetKeys(offsets); auto children = OffsetData::GetChildren(offsets); auto values = OffsetData::GetValues(offsets); @@ -84,39 +83,6 @@ static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVect selvec_size = keys_offset; } -static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, SelectionVector &sel, - idx_t sel_size) { - auto &keys = VariantVector::GetKeys(variant); - auto &keys_entry = ListVector::GetEntry(keys); - auto keys_entry_data = FlatVector::GetData(keys_entry); - - bool already_sorted = true; - - vector unsorted_to_sorted(dictionary.size()); - auto it = dictionary.begin(); - for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { - auto unsorted_idx = it->second; - if (unsorted_idx != sorted_idx) { - already_sorted = false; - } - unsorted_to_sorted[unsorted_idx] = sorted_idx; - D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); - keys_entry_data[sorted_idx] = it->first; - auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); - keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); - it++; - } - - if (!already_sorted) { - //! Adjust the selection vector to point to the right dictionary index - for (idx_t i = 0; i < sel_size; i++) { - auto &entry = sel[i]; - auto sorted_idx = unsorted_to_sorted[entry]; - entry = sorted_idx; - } - } -} - static bool GatherOffsetsAndSizes(ToVariantSourceData &source, ToVariantGlobalResultData &result, idx_t count) { InitializeOffsets(result.offsets, count); //! First pass - collect sizes/offsets @@ -130,6 +96,9 @@ static bool WriteVariantResultData(ToVariantSourceData &source, ToVariantGlobalR } static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + if (!count) { + return true; + } DataChunk offsets; offsets.Initialize(Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, @@ -163,7 +132,7 @@ static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParam } } - FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); //! 
Finalize the 'data' auto &blob = VariantVector::GetData(result); auto blob_data = FlatVector::GetData(blob); diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp index d73bc38e5..c077cd87c 100644 --- a/src/duckdb/src/function/cast_rules.cpp +++ b/src/duckdb/src/function/cast_rules.cpp @@ -146,7 +146,6 @@ static int64_t ImplicitCastUSmallint(const LogicalType &to) { static int64_t ImplicitCastUInteger(const LogicalType &to) { switch (to.id()) { - case LogicalTypeId::UBIGINT: case LogicalTypeId::BIGINT: case LogicalTypeId::UHUGEINT: @@ -187,7 +186,6 @@ static int64_t ImplicitCastFloat(const LogicalType &to) { static int64_t ImplicitCastDouble(const LogicalType &to) { switch (to.id()) { - case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: @@ -500,7 +498,6 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) int64_t cost = -1; if (named_struct_cast) { - // Collect the target members in a map for easy lookup case_insensitive_map_t target_members; for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { diff --git a/src/duckdb/src/function/copy_blob.cpp b/src/duckdb/src/function/copy_blob.cpp index 2af12a8c3..398eb9534 100644 --- a/src/duckdb/src/function/copy_blob.cpp +++ b/src/duckdb/src/function/copy_blob.cpp @@ -61,7 +61,6 @@ struct WriteBlobGlobalState final : public GlobalFunctionData { unique_ptr WriteBlobInitializeGlobal(ClientContext &context, FunctionData &bind_data, const string &file_path) { - auto &bdata = bind_data.Cast(); auto &fs = FileSystem::GetFileSystem(context); @@ -102,7 +101,6 @@ void WriteBlobSink(ExecutionContext &context, FunctionData &bind_data, GlobalFun for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { const auto out_idx = vdata.sel->get_index(row_idx); if (vdata.validity.RowIsValid(out_idx)) { - auto &blob = blobs[out_idx]; auto blob_len = blob.GetSize(); auto blob_ptr = blob.GetDataWriteable(); diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index c9ad55d92..8521dfe71 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ b/src/duckdb/src/function/function_binder.cpp @@ -334,8 +334,8 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE // Some functions may have an invalid default return type, as they must be bound to infer the return type. // In those cases, we default to SQLNULL. const auto return_type_if_null = - bound_function.return_type.IsComplete() ? bound_function.return_type : LogicalType::SQLNULL; - if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + bound_function.GetReturnType().IsComplete() ? 
bound_function.GetReturnType() : LogicalType::SQLNULL; + if (bound_function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { for (auto &child : children) { if (child->return_type == LogicalTypeId::SQLNULL) { return make_uniq(Value(return_type_if_null)); @@ -378,7 +378,7 @@ static string ExtractCollation(const vector> &children) { static void PropagateCollations(ClientContext &, ScalarFunction &bound_function, vector> &children) { - if (!RequiresCollationPropagation(bound_function.return_type)) { + if (!RequiresCollationPropagation(bound_function.GetReturnType())) { // we only need to propagate if the function returns a varchar return; } @@ -389,7 +389,7 @@ static void PropagateCollations(ClientContext &, ScalarFunction &bound_function, } // propagate the collation to the return type auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - bound_function.return_type = std::move(collation_type); + bound_function.SetReturnType(std::move(collation_type)); } static void PushCollations(ClientContext &context, ScalarFunction &bound_function, @@ -401,8 +401,8 @@ static void PushCollations(ClientContext &context, ScalarFunction &bound_functio } // push collation into the return type if required auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - if (RequiresCollationPropagation(bound_function.return_type)) { - bound_function.return_type = collation_type; + if (RequiresCollationPropagation(bound_function.GetReturnType())) { + bound_function.SetReturnType(collation_type); } // push collations to the children for (auto &arg : children) { @@ -417,7 +417,7 @@ static void PushCollations(ClientContext &context, ScalarFunction &bound_functio static void HandleCollations(ClientContext &context, ScalarFunction &bound_function, vector> &children) { - switch (bound_function.collation_handling) { + switch (bound_function.GetCollationHandling()) { case FunctionCollationHandling::IGNORE_COLLATIONS: // explicitly ignoring collation handling break; @@ -436,7 +436,6 @@ static void HandleCollations(ClientContext &context, ScalarFunction &bound_funct static void InferTemplateType(ClientContext &context, const LogicalType &source, const LogicalType &target, case_insensitive_map_t> &bindings, const Expression ¤t_expr, const BaseScalarFunction &function) { - if (target.id() == LogicalTypeId::UNKNOWN || target.id() == LogicalTypeId::SQLNULL) { // If the actual type is unknown, we cannot infer anything more. // Therefore, we map all remaining templates in the source to UNKNOWN or SQLNULL, if not already inferred to @@ -517,7 +516,6 @@ static void InferTemplateType(ClientContext &context, const LogicalType &source, case LogicalTypeId::ARRAY: { if ((source.id() == LogicalTypeId::ARRAY || source.id() == LogicalTypeId::LIST) && (target.id() == LogicalTypeId::LIST || target.id() == LogicalTypeId::ARRAY)) { - const auto &source_child = source.id() == LogicalTypeId::LIST ? ListType::GetChildType(source) : ArrayType::GetChildType(source); const auto &target_child = @@ -565,7 +563,6 @@ static void InferTemplateType(ClientContext &context, const LogicalType &source, static void SubstituteTemplateType(LogicalType &type, case_insensitive_map_t> &bindings, const string &function_name) { - // Replace all template types in with their bound concrete types. 
type = TypeVisitor::VisitReplace(type, [&](const LogicalType &t) -> LogicalType { if (t.id() == LogicalTypeId::TEMPLATE) { @@ -614,8 +611,8 @@ void FunctionBinder::ResolveTemplateTypes(BaseScalarFunction &bound_function, } // If the return type is templated, we need to subsitute it as well - if (bound_function.return_type.IsTemplated()) { - to_substitute.emplace_back(bound_function.return_type); + if (bound_function.GetReturnType().IsTemplated()) { + to_substitute.emplace_back(bound_function.GetReturnType()); } // Finally, substitute all template types in the bound function with their concrete types. @@ -641,37 +638,36 @@ void FunctionBinder::CheckTemplateTypesResolved(const BaseScalarFunction &bound_ VerifyTemplateType(arg, bound_function.name); } VerifyTemplateType(bound_function.varargs, bound_function.name); - VerifyTemplateType(bound_function.return_type, bound_function.name); + VerifyTemplateType(bound_function.GetReturnType(), bound_function.name); } unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, vector> children, bool is_operator, optional_ptr binder) { - // Attempt to resolve template types, before we call the "Bind" callback. ResolveTemplateTypes(bound_function, children); unique_ptr bind_info; - if (bound_function.bind) { - bind_info = bound_function.bind(context, bound_function, children); - } else if (bound_function.bind_extended) { + if (bound_function.HasBindCallback()) { + bind_info = bound_function.GetBindCallback()(context, bound_function, children); + } else if (bound_function.HasBindExtendedCallback()) { if (!binder) { throw InternalException("Function '%s' has a 'bind_extended' but the FunctionBinder was created without " "a reference to a Binder", bound_function.name); } ScalarFunctionBindInput bind_input(*binder); - bind_info = bound_function.bind_extended(bind_input, bound_function, children); + bind_info = bound_function.GetBindExtendedCallback()(bind_input, bound_function, children); } // After the "bind" callback, we verify that all template types are bound to concrete types. 
CheckTemplateTypesResolved(bound_function); - if (bound_function.get_modified_databases && binder) { + if (bound_function.HasModifiedDatabasesCallback() && binder) { auto &properties = binder->GetStatementProperties(); FunctionModifiedDatabasesInput input(bind_info, properties); - bound_function.get_modified_databases(context, input); + bound_function.GetModifiedDatabasesCallback()(context, input); } HandleCollations(context, bound_function, children); @@ -679,14 +675,14 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_f // check if we need to add casts to the children CastToFunctionArguments(bound_function, children); - auto return_type = bound_function.return_type; + auto return_type = bound_function.GetReturnType(); unique_ptr result; auto result_func = make_uniq(std::move(return_type), std::move(bound_function), std::move(children), std::move(bind_info), is_operator); - if (result_func->function.bind_expression) { + if (result_func->function.HasBindExpressionCallback()) { // if a bind_expression callback is registered - call it and emit the resulting expression FunctionBindExpressionInput input(context, result_func->bind_info.get(), result_func->children); - result = result_func->function.bind_expression(input); + result = result_func->function.GetBindExpressionCallback()(input); } if (!result) { result = std::move(result_func); @@ -698,12 +694,11 @@ unique_ptr FunctionBinder::BindAggregateFunction(Aggre vector> children, unique_ptr filter, AggregateType aggr_type) { - ResolveTemplateTypes(bound_function, children); unique_ptr bind_info; - if (bound_function.bind) { - bind_info = bound_function.bind(context, bound_function, children); + if (bound_function.HasBindCallback()) { + bind_info = bound_function.GetBindCallback()(context, bound_function, children); // we may have lost some arguments in the bind children.resize(MinValue(bound_function.arguments.size(), children.size())); } diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index d73467d3a..df12b8c01 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -4,6 +4,7 @@ #include "duckdb/function/scalar/compressed_materialization_functions.hpp" #include "duckdb/function/scalar/date_functions.hpp" #include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/function/scalar/geometry_functions.hpp" #include "duckdb/function/scalar/list_functions.hpp" #include "duckdb/function/scalar/map_functions.hpp" #include "duckdb/function/scalar/variant_functions.hpp" @@ -15,6 +16,7 @@ #include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + namespace duckdb { // Scalar Function @@ -45,6 +47,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(NotLikeFun), DUCKDB_SCALAR_FUNCTION(NotILikeFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorModuloFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StIntersectsExtentFunAlias), DUCKDB_SCALAR_FUNCTION_SET(OperatorMultiplyFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorAddFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorSubtractFun), @@ -78,6 +81,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(ArrayExtractFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayIndexofFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayIntersectFun), DUCKDB_SCALAR_FUNCTION_SET(ArrayLengthFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayPositionFun), 
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayResizeFun), @@ -119,6 +123,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(ListExtractFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasFun), DUCKDB_SCALAR_FUNCTION_ALIAS(ListIndexofFun), + DUCKDB_SCALAR_FUNCTION(ListIntersectFun), DUCKDB_SCALAR_FUNCTION(ListPositionFun), DUCKDB_SCALAR_FUNCTION_SET(ListResizeFun), DUCKDB_SCALAR_FUNCTION(ListSelectFun), @@ -151,6 +156,12 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(SHA1Fun), DUCKDB_SCALAR_FUNCTION_SET(SHA256Fun), DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StAsbinaryFun), + DUCKDB_SCALAR_FUNCTION(StAstextFun), + DUCKDB_SCALAR_FUNCTION(StAswkbFun), + DUCKDB_SCALAR_FUNCTION_ALIAS(StAswktFun), + DUCKDB_SCALAR_FUNCTION(StGeomfromwkbFun), + DUCKDB_SCALAR_FUNCTION(StIntersectsExtentFun), DUCKDB_SCALAR_FUNCTION_ALIAS(StrSplitFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(StrSplitRegexFun), DUCKDB_SCALAR_FUNCTION_SET(StrfTimeFun), @@ -177,6 +188,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), DUCKDB_SCALAR_FUNCTION(UpperFun), DUCKDB_SCALAR_FUNCTION_SET(VariantExtractFun), + DUCKDB_SCALAR_FUNCTION(VariantNormalizeFun), DUCKDB_SCALAR_FUNCTION(VariantTypeofFun), DUCKDB_SCALAR_FUNCTION_SET(WriteLogFun), DUCKDB_SCALAR_FUNCTION(ConcatOperatorFun), diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp index 2f407c025..0e7c09958 100644 --- a/src/duckdb/src/function/macro_function.cpp +++ b/src/duckdb/src/function/macro_function.cpp @@ -45,16 +45,33 @@ MacroBindResult MacroFunction::BindMacroFunction( Binder &binder, const vector> &functions, const string &name, FunctionExpression &function_expr, vector> &positional_arguments, InsertionOrderPreservingMap> &named_arguments, idx_t depth) { - ExpressionBinder expr_binder(binder, binder.context); + expr_binder.lambda_bindings = binder.lambda_bindings; + + // Figure out whether we even need to bind arguments + bool requires_bind = false; + for (auto &function : functions) { + for (const auto &type : function->types) { + if (type.id() != LogicalTypeId::UNKNOWN) { + requires_bind = true; + break; + } + } + if (requires_bind) { + break; + } + } // Find argument types and separate positional and default arguments vector positional_arg_types; InsertionOrderPreservingMap named_arg_types; for (auto &arg : function_expr.children) { auto arg_copy = arg->Copy(); - const auto arg_bind_result = expr_binder.BindExpression(arg_copy, depth + 1); - auto arg_type = arg_bind_result.HasError() ? LogicalType::UNKNOWN : arg_bind_result.expression->return_type; + LogicalType arg_type = LogicalType::UNKNOWN; + if (requires_bind) { + const auto arg_bind_result = expr_binder.BindExpression(arg_copy, depth + 1); + arg_type = arg_bind_result.HasError() ? 
LogicalType::UNKNOWN : arg_bind_result.expression->return_type; + } if (!arg->GetAlias().empty()) { // Default argument if (named_arguments.find(arg->GetAlias()) != named_arguments.end()) { diff --git a/src/duckdb/src/function/pragma/pragma_queries.cpp b/src/duckdb/src/function/pragma/pragma_queries.cpp index 9107b8c01..62ce195df 100644 --- a/src/duckdb/src/function/pragma/pragma_queries.cpp +++ b/src/duckdb/src/function/pragma/pragma_queries.cpp @@ -11,6 +11,7 @@ #include "duckdb/parser/qualified_name.hpp" #include "duckdb/parser/statement/copy_statement.hpp" #include "duckdb/parser/statement/export_statement.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp index 740924397..2c4b72ead 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp @@ -190,14 +190,14 @@ scalar_function_t GetIntegralDecompressFunctionInputSwitch(const LogicalType &in void CMIntegralSerialize(Serializer &serializer, const optional_ptr bind_data, const ScalarFunction &function) { serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); } template unique_ptr CMIntegralDeserialize(Deserializer &deserializer, ScalarFunction &function) { function.arguments = deserializer.ReadProperty>(100, "arguments"); auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GET_FUNCTION(function.arguments[0], return_type); + function.SetFunctionCallback(GET_FUNCTION(function.arguments[0], return_type)); return nullptr; } @@ -226,12 +226,12 @@ ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_typ ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { ScalarFunction result(IntegralCompressFunctionName(result_type), {input_type, input_type}, result_type, GetIntegralCompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; + result.SetSerializeCallback(CMIntegralSerialize); + result.SetDeserializeCallback(CMIntegralDeserialize); #if defined(D_ASSERT_IS_ENABLED) - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled + result.SetFallible(); // Can only throw runtime error when assertions are enabled #else - result.errors = FunctionErrors::CANNOT_ERROR; + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); #endif return result; } @@ -239,8 +239,8 @@ ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, ScalarFunction CMIntegralDecompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { ScalarFunction result(IntegralDecompressFunctionName(result_type), {input_type, result_type}, result_type, GetIntegralDecompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; + result.SetSerializeCallback(CMIntegralSerialize); + result.SetDeserializeCallback(CMIntegralDeserialize); return result; } diff --git 
a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp index 39821858d..80fda3a0c 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -44,7 +44,7 @@ inline RESULT_TYPE StringCompressInternal(const string_t &input) { memset(result_ptr, '\0', remainder); } result_ptr[0] = UnsafeNumericCast(input.GetSize()); - return result; + return BSwapIfBE(result); } template @@ -55,13 +55,15 @@ inline RESULT_TYPE StringCompress(const string_t &input) { template inline RESULT_TYPE MiniStringCompress(const string_t &input) { + RESULT_TYPE result; if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { - return UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPrefix())); + result = UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPrefix())); } else if (input.GetSize() == 0) { - return 0; + result = 0; } else { - return UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPointer())); + result = UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPointer())); } + return BSwapIfBE(result); } template <> @@ -126,19 +128,20 @@ struct StringDecompressLocalState : public FunctionLocalState { template inline string_t StringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - const auto input_ptr = const_data_ptr_cast(&input); - string_t result(input_ptr[0]); + const auto le_input = BSwapIfBE(input); + const auto le_input_str = const_data_ptr_cast(&le_input); + string_t result(le_input_str[0]); if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr); + TemplatedReverseMemCpy(result_ptr, le_input_str); memset(result_ptr + sizeof(INPUT_TYPE) - 1, '\0', string_t::INLINE_LENGTH - sizeof(INPUT_TYPE) + 1); } else if (result.GetSize() <= string_t::INLINE_LENGTH) { static constexpr auto REMAINDER = sizeof(INPUT_TYPE) - string_t::INLINE_LENGTH; const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr + REMAINDER); + TemplatedReverseMemCpy(result_ptr, le_input_str + REMAINDER); } else { result.SetPointer(char_ptr_cast(allocator.Allocate(sizeof(INPUT_TYPE)))); - TemplatedReverseMemCpy(data_ptr_cast(result.GetPointer()), input_ptr); + TemplatedReverseMemCpy(data_ptr_cast(result.GetPointer()), le_input_str); memcpy(result.GetPrefixWriteable(), result.GetPointer(), string_t::PREFIX_LENGTH); } return result; @@ -146,7 +149,8 @@ inline string_t StringDecompress(const INPUT_TYPE &input, ArenaAllocator &alloca template inline string_t MiniStringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - if (input == 0) { + const auto le_input = BSwapIfBE(input); + if (le_input == 0) { string_t result(uint32_t(0)); memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); return result; @@ -155,10 +159,10 @@ inline string_t MiniStringDecompress(const INPUT_TYPE &input, ArenaAllocator &al string_t result(1); if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); - *data_ptr_cast(result.GetPrefixWriteable()) = input - 1; + *data_ptr_cast(result.GetPrefixWriteable()) = le_input - 1; } else { result.SetPointer(char_ptr_cast(allocator.Allocate(1))); - *data_ptr_cast(result.GetPointer()) = 
input - 1; + *data_ptr_cast(result.GetPointer()) = le_input - 1; memset(result.GetPrefixWriteable(), '\0', string_t::PREFIX_LENGTH); *result.GetPrefixWriteable() = *result.GetPointer(); } @@ -198,7 +202,7 @@ scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_typ case LogicalTypeId::UHUGEINT: return GetStringDecompressFunction(input_type); case LogicalTypeId::HUGEINT: - return GetStringCompressFunction(input_type); + return GetStringDecompressFunction(input_type); default: throw InternalException("Unexpected type in GetStringDecompressFunctionSwitch"); } @@ -207,13 +211,13 @@ scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_typ void CMStringCompressSerialize(Serializer &serializer, const optional_ptr bind_data, const ScalarFunction &function) { serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); } unique_ptr CMStringCompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { function.arguments = deserializer.ReadProperty>(100, "arguments"); auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GetStringCompressFunctionSwitch(return_type); + function.SetFunctionCallback(GetStringCompressFunctionSwitch(return_type)); return nullptr; } @@ -224,8 +228,8 @@ void CMStringDecompressSerialize(Serializer &serializer, const optional_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { function.arguments = deserializer.ReadProperty>(100, "arguments"); - function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); - function.return_type = deserializer.Get(); + function.SetFunctionCallback(GetStringDecompressFunctionSwitch(function.arguments[0])); + function.SetReturnType(deserializer.Get()); return nullptr; } @@ -245,12 +249,12 @@ ScalarFunctionSet GetStringDecompressFunctionSet() { ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) { ScalarFunction result(StringCompressFunctionName(result_type), {LogicalType::VARCHAR}, result_type, GetStringCompressFunctionSwitch(result_type), CMUtils::Bind); - result.serialize = CMStringCompressSerialize; - result.deserialize = CMStringCompressDeserialize; + result.SetSerializeCallback(CMStringCompressSerialize); + result.SetDeserializeCallback(CMStringCompressDeserialize); #if defined(D_ASSERT_IS_ENABLED) - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled + result.SetFallible(); // Can only throw runtime error when assertions are enabled #else - result.errors = FunctionErrors::CANNOT_ERROR; + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); #endif return result; } @@ -259,8 +263,8 @@ ScalarFunction CMStringDecompressFun::GetFunction(const LogicalType &input_type) ScalarFunction result(StringDecompressFunctionName(), {input_type}, LogicalType::VARCHAR, GetStringDecompressFunctionSwitch(input_type), CMUtils::Bind, nullptr, nullptr, StringDecompressLocalState::Init); - result.serialize = CMStringDecompressSerialize; - result.deserialize = CMStringDecompressDeserialize; + result.SetSerializeCallback(CMStringDecompressSerialize); + result.SetDeserializeCallback(CMStringDecompressDeserialize); return result; } diff --git a/src/duckdb/src/function/scalar/create_sort_key.cpp b/src/duckdb/src/function/scalar/create_sort_key.cpp index 2f5463e3f..9c043d4e6 100644 --- 
a/src/duckdb/src/function/scalar/create_sort_key.cpp +++ b/src/duckdb/src/function/scalar/create_sort_key.cpp @@ -63,7 +63,7 @@ unique_ptr CreateSortKeyBind(ClientContext &context, ScalarFunctio } if (all_constant) { if (constant_size <= sizeof(int64_t)) { - bound_function.return_type = LogicalType::BIGINT; + bound_function.SetReturnType(LogicalType::BIGINT); } } return std::move(result); @@ -696,20 +696,22 @@ void PrepareSortData(Vector &result, idx_t size, SortKeyLengthInfo &key_lengths, } } -void FinalizeSortData(Vector &result, idx_t size) { +void FinalizeSortData(Vector &result, idx_t size, const SortKeyLengthInfo &key_lengths, + const unsafe_vector &offsets) { switch (result.GetType().id()) { case LogicalTypeId::BLOB: { auto result_data = FlatVector::GetData(result); // call Finalize on the result for (idx_t r = 0; r < size; r++) { - result_data[r].Finalize(); + result_data[r].SetSizeAndFinalize(NumericCast(offsets[r]), + key_lengths.variable_lengths[r] + key_lengths.constant_length); } break; } case LogicalTypeId::BIGINT: { auto result_data = FlatVector::GetData(result); for (idx_t r = 0; r < size; r++) { - result_data[r] = BSwap(result_data[r]); + result_data[r] = BSwapIfLE(result_data[r]); } break; } @@ -739,7 +741,7 @@ void CreateSortKeyInternal(vector> &sort_key_data, SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); ConstructSortKey(*sort_key_data[c], info); } - FinalizeSortData(result, row_count); + FinalizeSortData(result, row_count, key_lengths, offsets); } } // namespace @@ -861,7 +863,7 @@ unique_ptr DecodeSortKeyBind(ClientContext &context, ScalarFunctio throw BinderException("sort_key must be either BIGINT or BLOB, got %s instead", sort_key_arg.return_type.ToString()); } - bound_function.return_type = LogicalType::STRUCT(std::move(children)); + bound_function.SetReturnType(LogicalType::STRUCT(std::move(children))); return std::move(result); } @@ -1156,11 +1158,13 @@ void DecodeSortKeyRecursive(DecodeSortKeyData decode_data[], DecodeSortKeyVector } // namespace -void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, - OrderModifiers modifiers) { +idx_t CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, + OrderModifiers modifiers) { DecodeSortKeyVectorData sort_key_data(result.GetType(), modifiers); DecodeSortKeyData decode_data(sort_key); DecodeSortKeyRecursive(&decode_data, sort_key_data, result, result_idx, 1); + + return decode_data.position; } void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, @@ -1209,13 +1213,13 @@ static void DecodeSortKeyFunction(DataChunk &args, ExpressionState &state, Vecto for (idx_t i = 0; i < count; i++) { const auto idx = sort_key_vec_format.sel->get_index(i); D_ASSERT(sort_key_vec_format.validity.RowIsValid(idx)); - bswapped_ints[i] = BSwap(sort_keys[idx]); + bswapped_ints[i] = BSwapIfLE(sort_keys[idx]); decode_data[i] = DecodeSortKeyData(bswapped_ints[i]); } } else { for (idx_t i = 0; i < count; i++) { D_ASSERT(sort_key_vec_format.validity.RowIsValid(i)); - bswapped_ints[i] = BSwap(sort_keys[i]); + bswapped_ints[i] = BSwapIfLE(sort_keys[i]); decode_data[i] = DecodeSortKeyData(bswapped_ints[i]); } } @@ -1242,7 +1246,7 @@ ScalarFunction CreateSortKeyFun::GetFunction() { ScalarFunction sort_key_function("create_sort_key", {LogicalType::ANY}, LogicalType::BLOB, CreateSortKeyFunction, CreateSortKeyBind); sort_key_function.varargs = LogicalType::ANY; - sort_key_function.null_handling = 
FunctionNullHandling::SPECIAL_HANDLING; + sort_key_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return sort_key_function; } diff --git a/src/duckdb/src/function/scalar/date/strftime.cpp b/src/duckdb/src/function/scalar/date/strftime.cpp index 66a044f34..913bcd28e 100644 --- a/src/duckdb/src/function/scalar/date/strftime.cpp +++ b/src/duckdb/src/function/scalar/date/strftime.cpp @@ -148,7 +148,6 @@ inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, timestamp } struct StrpTimeFunction { - template static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); @@ -225,13 +224,13 @@ struct StrpTimeFunction { error); } if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); } else if (format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED)) { - bound_function.return_type = LogicalType::TIMESTAMP_NS; + bound_function.SetReturnType(LogicalType::TIMESTAMP_NS); if (bound_function.name == "strptime") { - bound_function.function = Parse; + bound_function.SetFunctionCallback(Parse); } else { - bound_function.function = TryParse; + bound_function.SetFunctionCallback(TryParse); } } return make_uniq(format, format_string); @@ -261,15 +260,15 @@ struct StrpTimeFunction { if (has_offset) { // If any format has UTC offsets, then we have to produce TSTZ - bound_function.return_type = LogicalType::TIMESTAMP_TZ; + bound_function.SetReturnType(LogicalType::TIMESTAMP_TZ); } else if (has_nanos) { // If any format has nanoseconds, then we have to produce TSNS // unless there is an offset, in which case we produce - bound_function.return_type = LogicalType::TIMESTAMP_NS; + bound_function.SetReturnType(LogicalType::TIMESTAMP_NS); if (bound_function.name == "strptime") { - bound_function.function = Parse; + bound_function.SetFunctionCallback(Parse); } else { - bound_function.function = TryParse; + bound_function.SetFunctionCallback(TryParse); } } return make_uniq(formats, format_strings); @@ -304,14 +303,14 @@ ScalarFunctionSet StrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetFallible(); strptime.AddFunction(fun); fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::Parse, StrpTimeFunction::Bind); - BaseScalarFunction::SetReturnsError(fun); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetFallible(); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); strptime.AddFunction(fun); return strptime; } @@ -322,12 +321,12 @@ ScalarFunctionSet TryStrpTimeFun::GetFunctions() { const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); try_strptime.AddFunction(fun); fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, StrpTimeFunction::TryParse, StrpTimeFunction::Bind); 
- fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); try_strptime.AddFunction(fun); return try_strptime; diff --git a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp index c5c4307bd..ddaf25649 100644 --- a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp +++ b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp @@ -81,7 +81,7 @@ unique_ptr ConstantOrNullBind(ClientContext &context, ScalarFuncti } D_ASSERT(arguments.size() >= 2); auto value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); return make_uniq(std::move(value)); } @@ -104,7 +104,7 @@ bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value ScalarFunction ConstantOrNullFun::GetFunction() { auto fun = ScalarFunction("constant_or_null", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ConstantOrNullFunction); - fun.bind = ConstantOrNullBind; + fun.SetBindCallback(ConstantOrNullBind); fun.varargs = LogicalType::ANY; return fun; } diff --git a/src/duckdb/src/function/scalar/generic/error.cpp b/src/duckdb/src/function/scalar/generic/error.cpp index 30d2a5f13..f2847786c 100644 --- a/src/duckdb/src/function/scalar/generic/error.cpp +++ b/src/duckdb/src/function/scalar/generic/error.cpp @@ -26,8 +26,8 @@ static void ErrorFunction(DataChunk &args, ExpressionState &state, Vector &resul ScalarFunction ErrorFun::GetFunction() { auto fun = ScalarFunction("error", {LogicalType::VARCHAR}, LogicalType::SQLNULL, ErrorFunction); // Set the function with side effects to avoid the optimization. 
- fun.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(fun); + fun.SetVolatile(); + fun.SetFallible(); return fun; } diff --git a/src/duckdb/src/function/scalar/generic/getvariable.cpp b/src/duckdb/src/function/scalar/generic/getvariable.cpp index 52c63488f..b5b9eb013 100644 --- a/src/duckdb/src/function/scalar/generic/getvariable.cpp +++ b/src/duckdb/src/function/scalar/generic/getvariable.cpp @@ -36,7 +36,7 @@ unique_ptr GetVariableBind(ClientContext &context, ScalarFunction if (!variable_name.IsNull()) { ClientConfig::GetConfig(context).GetUserVariable(variable_name.ToString(), value); } - function.return_type = value.type(); + function.SetReturnType(value.type()); return make_uniq(std::move(value)); } @@ -54,7 +54,7 @@ unique_ptr BindGetVariableExpression(FunctionBindExpressionInput &in ScalarFunction GetVariableFun::GetFunction() { ScalarFunction getvar("getvariable", {LogicalType::VARCHAR}, LogicalType::ANY, nullptr, GetVariableBind, nullptr); - getvar.bind_expression = BindGetVariableExpression; + getvar.SetBindExpressionCallback(BindGetVariableExpression); return getvar; } diff --git a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp new file mode 100644 index 000000000..18a259460 --- /dev/null +++ b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp @@ -0,0 +1,65 @@ +#include "duckdb/function/scalar/geometry_functions.hpp" +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/common/vector_operations/binary_executor.hpp" + +namespace duckdb { + +static void FromWKBFunction(DataChunk &input, ExpressionState &state, Vector &result) { + Geometry::FromBinary(input.data[0], result, input.size(), true); +} + +ScalarFunction StGeomfromwkbFun::GetFunction() { + ScalarFunction function({LogicalType::BLOB}, LogicalType::GEOMETRY(), FromWKBFunction); + return function; +} + +static void ToWKBFunction(DataChunk &input, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute<string_t, string_t>(input.data[0], result, input.size(), [&](const string_t &geom) { + // TODO: convert to internal representation + return geom; + }); + // Add a heap reference to the input WKB to prevent it from being freed + StringVector::AddHeapReference(input.data[0], result); +} + +ScalarFunction StAswkbFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY()}, LogicalType::BLOB, ToWKBFunction); + return function; +} + +static void ToWKTFunction(DataChunk &input, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute<string_t, string_t>(input.data[0], result, input.size(), + [&](const string_t &geom) { return Geometry::ToString(result, geom); }); +} + +ScalarFunction StAstextFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY()}, LogicalType::VARCHAR, ToWKTFunction); + return function; +} + +static void IntersectsExtentFunction(DataChunk &input, ExpressionState &state, Vector &result) { + BinaryExecutor::Execute<string_t, string_t, bool>( + input.data[0], input.data[1], result, input.size(), [](const string_t &lhs_geom, const string_t &rhs_geom) { + auto lhs_extent = GeometryExtent::Empty(); + auto rhs_extent = GeometryExtent::Empty(); + + const auto lhs_is_empty = Geometry::GetExtent(lhs_geom, lhs_extent) == 0; + const auto rhs_is_empty = Geometry::GetExtent(rhs_geom, rhs_extent) == 0; + + if (lhs_is_empty || rhs_is_empty) { + // One of the geometries is empty + return false; + } + + // Don't take Z and M into account for intersection test + return lhs_extent.IntersectsXY(rhs_extent); + }); +} +
+ScalarFunction StIntersectsExtentFun::GetFunction() { + ScalarFunction function({LogicalType::GEOMETRY(), LogicalType::GEOMETRY()}, LogicalType::BOOLEAN, + IntersectsExtentFunction); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index 064bd4b00..bd4a2de51 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -34,7 +34,7 @@ ScalarFunction ListContainsFun::GetFunction() { ScalarFunction ListPositionFun::GetFunction() { auto fun = ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, LogicalType::INTEGER, ListSearchFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp index fd79249d9..d4ed220dd 100644 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -157,8 +157,8 @@ ScalarFunctionSet ListExtractFun::GetFunctions() { LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); - BaseScalarFunction::SetReturnsError(lfun); - BaseScalarFunction::SetReturnsError(sfun); + lfun.SetFallible(); + sfun.SetFallible(); list_extract_set.AddFunction(lfun); list_extract_set.AddFunction(sfun); return list_extract_set; diff --git a/src/duckdb/src/function/scalar/list/list_intersect.cpp b/src/duckdb/src/function/scalar/list/list_intersect.cpp new file mode 100644 index 000000000..e8bfef57a --- /dev/null +++ b/src/duckdb/src/function/scalar/list/list_intersect.cpp @@ -0,0 +1,197 @@ +#include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/function/scalar/list_functions.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/function/create_sort_key.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +static idx_t CalculateMaxResultLength(idx_t row_count, const UnifiedVectorFormat &l_format, + const UnifiedVectorFormat &r_format, const list_entry_t *l_entries, + const list_entry_t *r_entries) { + idx_t max_result_length = 0; + for (idx_t i = 0; i < row_count; i++) { + const auto l_idx = l_format.sel->get_index(i); + const auto r_idx = r_format.sel->get_index(i); + + if (l_format.validity.RowIsValid(l_idx) && r_format.validity.RowIsValid(r_idx)) { + const auto &l_list = l_entries[l_idx]; + const auto &r_list = r_entries[r_idx]; + max_result_length += MinValue(l_list.length, r_list.length); + } + } + return max_result_length; +} + +static void ListIntersectFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto row_count = args.size(); + + // Handle NULL return type case + if (result.GetType() == LogicalType::SQLNULL) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, true); + return; + } + + auto &l_vec = args.data[0]; + auto &r_vec = args.data[1]; + + const auto l_size = ListVector::GetListSize(l_vec); + const auto r_size = ListVector::GetListSize(r_vec); + + auto &l_child = ListVector::GetEntry(l_vec); + auto &r_child = 
ListVector::GetEntry(r_vec); + + const auto current_left_child_type = l_child.GetType(); + + UnifiedVectorFormat l_format; + UnifiedVectorFormat r_format; + + l_vec.ToUnifiedFormat(row_count, l_format); + r_vec.ToUnifiedFormat(row_count, r_format); + + const auto l_entries = UnifiedVectorFormat::GetData<list_entry_t>(l_format); + const auto r_entries = UnifiedVectorFormat::GetData<list_entry_t>(r_format); + + UnifiedVectorFormat l_child_format; + UnifiedVectorFormat r_child_format; + + l_child.ToUnifiedFormat(l_size, l_child_format); + r_child.ToUnifiedFormat(r_size, r_child_format); + + Vector l_sortkey_vec(LogicalType::BLOB, l_size); + Vector r_sortkey_vec(LogicalType::BLOB, r_size); + + const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + + CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + + const auto l_sortkey_ptr = FlatVector::GetData<string_t>(l_sortkey_vec); + const auto r_sortkey_ptr = FlatVector::GetData<string_t>(r_sortkey_vec); + + auto *result_data = FlatVector::GetData<list_entry_t>(result); + auto &result_entry = ListVector::GetEntry(result); + + string_set_t set; + string_set_t result_set; + string_map_t<idx_t> key_to_index_map; + + const idx_t max_result_length = CalculateMaxResultLength(row_count, l_format, r_format, l_entries, r_entries); + + ListVector::Reserve(result, max_result_length); + ListVector::SetListSize(result, max_result_length); + + SelectionVector result_sel(max_result_length); + ValidityMask result_entry_validity_mask(max_result_length); + idx_t offset = 0; + + auto &result_validity = FlatVector::Validity(result); + for (idx_t i = 0; i < row_count; i++) { + const auto l_idx = l_format.sel->get_index(i); + const auto r_idx = r_format.sel->get_index(i); + + const bool l_valid = l_format.validity.RowIsValid(l_idx); + const bool r_valid = r_format.validity.RowIsValid(r_idx); + + result_data[i].offset = offset; + + if (!l_valid) { + result_validity.SetInvalid(i); + result_data[i].length = 0; + continue; + } + if (!r_valid) { + result_data[i].length = 0; + continue; + } + + const auto &l_list = l_entries[l_idx]; + const auto &r_list = r_entries[r_idx]; + + if (l_list.length == 0 || r_list.length == 0) { + result_data[i].length = 0; + continue; + } + + set.clear(); + result_set.clear(); + key_to_index_map.clear(); + + // Choose which side to hash and which to iterate + const bool use_l_for_hash = l_list.length <= r_list.length; + const auto &hash_list = use_l_for_hash ? l_list : r_list; + const auto &iter_list = use_l_for_hash ? r_list : l_list; + const auto &hash_fmt = use_l_for_hash ? l_child_format : r_child_format; + const auto &iter_fmt = use_l_for_hash ? r_child_format : l_child_format; + const auto *hash_keys = use_l_for_hash ? l_sortkey_ptr : r_sortkey_ptr; + const auto *iter_keys = use_l_for_hash ?
r_sortkey_ptr : l_sortkey_ptr; + + set.clear(); + key_to_index_map.clear(); + for (idx_t j = 0; j < hash_list.length; j++) { + const idx_t h_idx = hash_list.offset + j; + const idx_t h_entry = hash_fmt.sel->get_index(h_idx); + if (!hash_fmt.validity.RowIsValid(h_entry)) { + continue; + } + const auto &key = hash_keys[h_entry]; + set.insert(key); + if (use_l_for_hash) { + key_to_index_map[key] = h_idx; + } + } + + // Iterate the chosen side, but ALWAYS emit a LEFT index + result_set.clear(); + idx_t row_result_length = 0; + for (idx_t j = 0; j < iter_list.length; j++) { + const idx_t it_idx = iter_list.offset + j; + const idx_t it_entry = iter_fmt.sel->get_index(it_idx); + if (!iter_fmt.validity.RowIsValid(it_entry)) { + continue; + } + + const auto &key = iter_keys[it_entry]; + if (set.find(key) == set.end() || result_set.find(key) != result_set.end()) { + continue; + } + result_set.insert(key); + + const idx_t emit_left_idx = use_l_for_hash ? key_to_index_map[key] : it_idx; + + result_sel.set_index(offset + row_result_length, emit_left_idx); + row_result_length++; + } + + result_data[i].length = row_result_length; + offset += row_result_length; + } + + ListVector::SetListSize(result, offset); + + result_entry.Slice(l_child, result_sel, offset); + result_entry.Flatten(offset); + FlatVector::SetValidity(result_entry, result_entry_validity_mask); + + result.SetVectorType(args.AllConstant() ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); +} +static unique_ptr ListIntersectBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + D_ASSERT(bound_function.arguments.size() == 2); + arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); + arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1])); + return nullptr; +} + +ScalarFunction ListIntersectFun::GetFunction() { + auto fun = + ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::LIST(LogicalType::TEMPLATE("T"))}, + LogicalType::LIST(LogicalType::TEMPLATE("T")), ListIntersectFunction, ListIntersectBind); + fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp index d159a7204..19fd149e3 100644 --- a/src/duckdb/src/function/scalar/list/list_resize.cpp +++ b/src/duckdb/src/function/scalar/list/list_resize.cpp @@ -8,7 +8,6 @@ namespace duckdb { static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &result) { - // Early-out, if the return value is a constant NULL. if (result.GetType().id() == LogicalTypeId::SQLNULL) { result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -63,7 +62,6 @@ static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &resul idx_t offset = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto list_idx = lists_data.sel->get_index(row_idx); auto new_size_idx = new_sizes_data.sel->get_index(row_idx); @@ -134,14 +132,14 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun // Early-out, if the first argument is a constant NULL. 
if (arguments[0]->return_type == LogicalType::SQLNULL) { bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::SQLNULL); + return make_uniq(bound_function.GetReturnType()); } // Early-out, if the first argument is a prepared statement. if (arguments[0]->return_type == LogicalType::UNKNOWN) { - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(arguments[0]->return_type); + return make_uniq(bound_function.GetReturnType()); } // Attempt implicit casting, if the default type does not match list the list child type. @@ -151,19 +149,19 @@ static unique_ptr ListResizeBind(ClientContext &context, ScalarFun bound_function.arguments[2] = ListType::GetChildType(arguments[0]->return_type); } - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(arguments[0]->return_type); + return make_uniq(bound_function.GetReturnType()); } ScalarFunctionSet ListResizeFun::GetFunctions() { ScalarFunction simple_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - simple_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(simple_fun); + simple_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + simple_fun.SetFallible(); ScalarFunction default_value_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - default_value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(default_value_fun); + default_value_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + default_value_fun.SetFallible(); ScalarFunctionSet list_resize_set("list_resize"); list_resize_set.AddFunction(simple_fun); list_resize_set.AddFunction(default_value_fun); diff --git a/src/duckdb/src/function/scalar/list/list_zip.cpp b/src/duckdb/src/function/scalar/list/list_zip.cpp index ef39a989d..2f83b61d6 100644 --- a/src/duckdb/src/function/scalar/list/list_zip.cpp +++ b/src/duckdb/src/function/scalar/list/list_zip.cpp @@ -155,15 +155,14 @@ static unique_ptr ListZipBind(ClientContext &context, ScalarFuncti throw BinderException("Parameter type needs to be List"); } } - bound_function.return_type = LogicalType::LIST(LogicalType::STRUCT(struct_children)); - return make_uniq(bound_function.return_type); + bound_function.SetReturnType(LogicalType::LIST(LogicalType::STRUCT(struct_children))); + return make_uniq(bound_function.GetReturnType()); } ScalarFunction ListZipFun::GetFunction() { - auto fun = ScalarFunction({}, LogicalType::LIST(LogicalTypeId::STRUCT), ListZipFunction, ListZipBind); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/src/function/scalar/nested_functions.cpp b/src/duckdb/src/function/scalar/nested_functions.cpp index 2d5359c4e..b09f04275 100644 --- a/src/duckdb/src/function/scalar/nested_functions.cpp +++ b/src/duckdb/src/function/scalar/nested_functions.cpp @@ -3,21 +3,22 @@ namespace duckdb { void MapUtil::ReinterpretMap(Vector &result, Vector &input, idx_t count) { + // Copy the list size + 
const auto list_size = ListVector::GetListSize(input); + ListVector::SetListSize(result, list_size); + UnifiedVectorFormat input_data; input.ToUnifiedFormat(count, input_data); + // Copy the list validity FlatVector::SetValidity(result, input_data.validity); // Copy the struct validity UnifiedVectorFormat input_struct_data; - ListVector::GetEntry(input).ToUnifiedFormat(count, input_struct_data); + ListVector::GetEntry(input).ToUnifiedFormat(list_size, input_struct_data); auto &result_struct = ListVector::GetEntry(result); FlatVector::SetValidity(result_struct, input_struct_data.validity); - // Copy the list size - auto list_size = ListVector::GetListSize(input); - ListVector::SetListSize(result, list_size); - // Copy the list buffer (the list_entry_t data) result.CopyBuffer(input); diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index 82cd9b5b7..83224332b 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -179,7 +179,7 @@ unique_ptr PropagateNumericStats(ClientContext &context, Functio auto &bind_data = input.bind_data->Cast(); bind_data.check_overflow = false; } - expr.function.function = GetScalarIntegerFunction(expr.return_type.InternalType()); + expr.function.SetFunctionCallback(GetScalarIntegerFunction(expr.return_type.InternalType())); } auto result = NumericStats::CreateEmpty(expr.return_type); NumericStats::SetMin(result, new_min); @@ -239,7 +239,7 @@ unique_ptr BindDecimalArithmetic(ClientContext &conte bound_function.arguments[i] = result_type; } } - bound_function.return_type = result_type; + bound_function.SetReturnType(result_type); return bind_data; } @@ -249,18 +249,19 @@ unique_ptr BindDecimalAddSubtract(ClientContext &context, ScalarFu auto bind_data = BindDecimalArithmetic(context, bound_function, arguments); // now select the physical function to execute - auto &result_type = bound_function.return_type; + auto &result_type = bound_function.GetReturnType(); if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(result_type.InternalType())); } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(result_type.InternalType())); } if (result_type.InternalType() != PhysicalType::INT128 && result_type.InternalType() != PhysicalType::UINT128) { if (IS_SUBTRACT) { - bound_function.statistics = - PropagateNumericStats; + bound_function.SetStatisticsCallback( + PropagateNumericStats); } else { - bound_function.statistics = PropagateNumericStats; + bound_function.SetStatisticsCallback( + PropagateNumericStats); } } return std::move(bind_data); @@ -270,25 +271,24 @@ void SerializeDecimalArithmetic(Serializer &serializer, const optional_ptrCast(); serializer.WriteProperty(100, "check_overflow", bind_data.check_overflow); - serializer.WriteProperty(101, "return_type", function.return_type); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); serializer.WriteProperty(102, "arguments", function.arguments); } // TODO this is partially duplicated from the bind template unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer, ScalarFunction &bound_function) { - // // re-change the function pointers auto check_overflow = deserializer.ReadProperty(100, "check_overflow"); auto return_type = 
deserializer.ReadProperty(101, "return_type"); auto arguments = deserializer.ReadProperty>(102, "arguments"); if (check_overflow) { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(return_type.InternalType())); } else { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(return_type.InternalType())); } - bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again - bound_function.return_type = return_type; + bound_function.SetStatisticsCallback(nullptr); // TODO we likely dont want to do stats prop again + bound_function.SetReturnType(return_type); bound_function.arguments = arguments; auto bind_data = make_uniq(); @@ -298,7 +298,7 @@ unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer unique_ptr NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - bound_function.return_type = arguments[0]->return_type; + bound_function.SetReturnType(arguments[0]->return_type); bound_function.arguments[0] = arguments[0]->return_type; return nullptr; } @@ -353,21 +353,21 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (left_type.id() == LogicalTypeId::DECIMAL) { auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr, BindDecimalAddSubtract); - BaseScalarFunction::SetReturnsError(function); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; + function.SetFallible(); + function.SetSerializeCallback(SerializeDecimalArithmetic); + function.SetDeserializeCallback(DeserializeDecimalArithmetic); return function; } else if (left_type.IsIntegral()) { ScalarFunction function("+", {left_type, right_type}, left_type, GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, PropagateNumericStats); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else { ScalarFunction function("+", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } } @@ -376,7 +376,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi case LogicalTypeId::BIGNUM: if (right_type.id() == LogicalTypeId::BIGNUM) { ScalarFunction function("+", {left_type, right_type}, LogicalType::BIGNUM, BignumAdd); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -385,22 +385,22 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTEGER) { ScalarFunction function("+", {left_type, right_type}, LogicalType::DATE, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if 
(right_type.id() == LogicalTypeId::TIME_TZ) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -408,7 +408,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, right_type, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -416,28 +416,28 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIME_TZ) { ScalarFunction function( "+", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -445,12 +445,12 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -458,13 +458,13 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function( "+", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -472,7 +472,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return 
function; } break; @@ -589,24 +589,27 @@ struct DecimalNegateBindData : public FunctionData { unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto bind_data = make_uniq(); auto &decimal_type = arguments[0]->return_type; auto width = DecimalType::GetWidth(decimal_type); if (width <= Decimal::MAX_WIDTH_INT16) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); + bound_function.SetFunctionCallback( + ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT)); } else if (width <= Decimal::MAX_WIDTH_INT32) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); + bound_function.SetFunctionCallback( + ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER)); } else if (width <= Decimal::MAX_WIDTH_INT64) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); + bound_function.SetFunctionCallback( + ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT)); } else { D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); + bound_function.SetFunctionCallback( + ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT)); } decimal_type.Verify(); bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; + bound_function.SetReturnType(decimal_type); return nullptr; } @@ -672,7 +675,7 @@ unique_ptr NegateBindStatistics(ClientContext &context, Function ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { if (type.id() == LogicalTypeId::INTERVAL) { ScalarFunction func("-", {type}, type, ScalarFunction::UnaryFunction); - ScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } else if (type.id() == LogicalTypeId::DECIMAL) { ScalarFunction func("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); @@ -684,7 +687,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { D_ASSERT(type.IsNumeric()); ScalarFunction func("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, nullptr, NegateBindStatistics); - ScalarFunction::SetReturnsError(func); + func.SetFallible(); return func; } } @@ -694,22 +697,23 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (left_type.id() == LogicalTypeId::DECIMAL) { ScalarFunction function("-", {left_type, right_type}, left_type, nullptr, BindDecimalAddSubtract); - ScalarFunction::SetReturnsError(function); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; + function.SetFallible(); + function.SetSerializeCallback(SerializeDecimalArithmetic); + function.SetDeserializeCallback( + DeserializeDecimalArithmetic); return function; } else if (left_type.IsIntegral()) { ScalarFunction function( "-", {left_type, right_type}, left_type, GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, PropagateNumericStats); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else { ScalarFunction function("-", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } } @@ -723,18 +727,18 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction 
function("-", {left_type, right_type}, LogicalType::BIGINT, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTEGER) { ScalarFunction function("-", {left_type, right_type}, LogicalType::DATE, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("-", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -743,13 +747,13 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } else if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function( "-", {left_type, right_type}, LogicalType::TIMESTAMP, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -758,7 +762,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -766,7 +770,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const if (right_type.id() == LogicalTypeId::INTERVAL) { ScalarFunction function("-", {left_type, right_type}, LogicalType::TIME, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -775,7 +779,7 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const ScalarFunction function( "-", {left_type, right_type}, LogicalType::TIME_TZ, ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); + function.SetFallible(); return function; } break; @@ -861,7 +865,6 @@ struct MultiplyPropagateStatistics { unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto bind_data = make_uniq(); uint8_t result_width = 0, result_scale = 0; @@ -915,16 +918,17 @@ unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunct } } result_type.Verify(); - bound_function.return_type = result_type; + bound_function.SetReturnType(result_type); // now select the physical function to execute if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + bound_function.SetFunctionCallback( + GetScalarBinaryFunction(result_type.InternalType())); } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(result_type.InternalType())); } if (result_type.InternalType() != PhysicalType::INT128) { - bound_function.statistics = - PropagateNumericStats; + bound_function.SetStatisticsCallback( + PropagateNumericStats); } return std::move(bind_data); } @@ -936,8 +940,9 @@ ScalarFunctionSet OperatorMultiplyFun::GetFunctions() { for (auto &type : LogicalType::Numeric()) { if (type.id() == LogicalTypeId::DECIMAL) { ScalarFunction function({type, type}, type, nullptr, 
BindDecimalMultiply); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; + function.SetSerializeCallback(SerializeDecimalArithmetic); + function.SetDeserializeCallback( + DeserializeDecimalArithmetic); multiply.AddFunction(function); } else if (TypeIsIntegral(type.InternalType())) { multiply.AddFunction(ScalarFunction( @@ -962,7 +967,7 @@ ScalarFunctionSet OperatorMultiplyFun::GetFunctions() { ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, ScalarFunction::BinaryFunction)); for (auto &func : multiply.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return multiply; @@ -1096,9 +1101,10 @@ template unique_ptr BindBinaryFloatingPoint(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (DBConfig::GetSetting(context)) { - bound_function.function = GetScalarBinaryFunction(bound_function.return_type.InternalType()); + bound_function.SetFunctionCallback(GetScalarBinaryFunction(bound_function.GetReturnType().InternalType())); } else { - bound_function.function = GetBinaryFunctionIgnoreZero(bound_function.return_type.InternalType()); + bound_function.SetFunctionCallback( + GetBinaryFunctionIgnoreZero(bound_function.GetReturnType().InternalType())); } return nullptr; } @@ -1114,7 +1120,7 @@ ScalarFunctionSet OperatorFloatDivideFun::GetFunctions() { ScalarFunction({LogicalType::INTERVAL, LogicalType::DOUBLE}, LogicalType::INTERVAL, BinaryScalarFunctionIgnoreZero)); for (auto &func : fp_divide.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return fp_divide; } @@ -1130,7 +1136,7 @@ ScalarFunctionSet OperatorIntegerDivideFun::GetFunctions() { } } for (auto &func : full_divide.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return full_divide; } @@ -1148,10 +1154,10 @@ static unique_ptr BindDecimalModulo(ClientContext &context, Scalar for (auto &arg : bound_function.arguments) { arg = LogicalType::DOUBLE; } - bound_function.return_type = LogicalType::DOUBLE; + bound_function.SetReturnType(LogicalType::DOUBLE); } - auto &result_type = bound_function.return_type; - bound_function.function = GetBinaryFunctionIgnoreZero(result_type.InternalType()); + auto &result_type = bound_function.GetReturnType(); + bound_function.SetFunctionCallback(GetBinaryFunctionIgnoreZero(result_type.InternalType())); return std::move(bind_data); } @@ -1188,7 +1194,7 @@ ScalarFunctionSet OperatorModuloFun::GetFunctions() { } } for (auto &func : modulo.functions) { - ScalarFunction::SetReturnsError(func); + func.SetFallible(); } return modulo; @@ -1220,7 +1226,7 @@ hugeint_t InterpolateOperator::Operation(const hugeint_t &lo, const double d, co template <> uhugeint_t InterpolateOperator::Operation(const uhugeint_t &lo, const double d, const uhugeint_t &hi) { - return Hugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); + return Uhugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); } static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp index c053bb7f6..8a53af7bb 100644 --- a/src/duckdb/src/function/scalar/sequence/nextval.cpp +++ b/src/duckdb/src/function/scalar/sequence/nextval.cpp @@ -132,7 +132,7 @@ void NextValModifiedDatabases(ClientContext &context, FunctionModifiedDatabasesI return; } auto &seq = input.bind_data->Cast(); - 
input.properties.RegisterDBModify(seq.sequence.ParentCatalog(), context); + input.properties.RegisterDBModify(seq.sequence.ParentCatalog(), context, DatabaseModificationType::SEQUENCE); } } // namespace @@ -140,25 +140,25 @@ void NextValModifiedDatabases(ClientContext &context, FunctionModifiedDatabasesI ScalarFunction NextvalFun::GetFunction() { ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT, NextValFunction, nullptr, nullptr); - next_val.bind_extended = NextValBind; - next_val.stability = FunctionStability::VOLATILE; - next_val.serialize = Serialize; - next_val.deserialize = Deserialize; - next_val.get_modified_databases = NextValModifiedDatabases; - next_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(next_val); + next_val.SetBindExtendedCallback(NextValBind); + next_val.SetSerializeCallback(Serialize); + next_val.SetDeserializeCallback(Deserialize); + next_val.SetModifiedDatabasesCallback(NextValModifiedDatabases); + next_val.SetInitStateCallback(NextValLocalFunction); + next_val.SetVolatile(); + next_val.SetFallible(); return next_val; } ScalarFunction CurrvalFun::GetFunction() { ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT, NextValFunction, nullptr, nullptr); - curr_val.bind_extended = NextValBind; - curr_val.stability = FunctionStability::VOLATILE; - curr_val.serialize = Serialize; - curr_val.deserialize = Deserialize; - curr_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(curr_val); + curr_val.SetBindExtendedCallback(NextValBind); + curr_val.SetSerializeCallback(Serialize); + curr_val.SetDeserializeCallback(Deserialize); + curr_val.SetInitStateCallback(NextValLocalFunction); + curr_val.SetVolatile(); + curr_val.SetFallible(); return curr_val; } diff --git a/src/duckdb/src/function/scalar/string/caseconvert.cpp b/src/duckdb/src/function/scalar/string/caseconvert.cpp index 9bd369849..0309b79f7 100644 --- a/src/duckdb/src/function/scalar/string/caseconvert.cpp +++ b/src/duckdb/src/function/scalar/string/caseconvert.cpp @@ -135,7 +135,7 @@ static unique_ptr CaseConvertPropagateStats(ClientContext &conte D_ASSERT(child_stats.size() == 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = CaseConvertFunctionASCII; + expr.function.SetFunctionCallback(CaseConvertFunctionASCII); } return nullptr; } diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp index cae184de9..97a74cebe 100644 --- a/src/duckdb/src/function/scalar/string/concat.cpp +++ b/src/duckdb/src/function/scalar/string/concat.cpp @@ -208,9 +208,13 @@ void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result, void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); auto &info = func_expr.bind_info->Cast(); + if (info.return_type.id() == LogicalTypeId::SQLNULL) { + return; + } if (info.return_type.id() == LogicalTypeId::LIST) { return ListConcatFunction(args, state, result, info.is_operator); - } else if (info.is_operator) { + } + if (info.is_operator) { return ConcatOperator(args, state, result); } return StringConcatFunction(args, state, result); @@ -220,7 +224,7 @@ void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bo if (is_operator) { bound_function.arguments[0] = type; bound_function.arguments[1] = type; - bound_function.return_type = type; 
+ bound_function.SetReturnType(type); return; } @@ -228,7 +232,7 @@ void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bo arg = type; } bound_function.varargs = type; - bound_function.return_type = type; + bound_function.SetReturnType(type); } unique_ptr BindListConcat(ClientContext &context, ScalarFunction &bound_function, @@ -277,17 +281,18 @@ unique_ptr BindListConcat(ClientContext &context, ScalarFunction & if (all_null) { // all arguments are NULL SetArgumentType(bound_function, LogicalTypeId::SQLNULL, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } auto list_type = LogicalType::LIST(child_type); SetArgumentType(bound_function, list_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } unique_ptr BindConcatFunctionInternal(ClientContext &context, ScalarFunction &bound_function, vector> &arguments, bool is_operator) { bool list_concat = false; + bool all_null = true; // blob concat is only supported for the concat operator - regular concat converts to varchar bool all_blob = is_operator ? true : false; for (auto &arg : arguments) { @@ -300,15 +305,18 @@ unique_ptr BindConcatFunctionInternal(ClientContext &context, Scal if (arg->return_type.id() != LogicalTypeId::BLOB) { all_blob = false; } + if (arg->return_type.id() != LogicalTypeId::SQLNULL) { + all_null = false; + } } - if (list_concat) { + if (list_concat || all_null) { return BindListConcat(context, bound_function, arguments, is_operator); } auto return_type = all_blob ? LogicalType::BLOB : LogicalType::VARCHAR; // we can now assume that the input is a string or castable to a string SetArgumentType(bound_function, return_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); + return make_uniq(bound_function.GetReturnType(), is_operator); } unique_ptr BindConcatFunction(ClientContext &context, ScalarFunction &bound_function, @@ -337,7 +345,7 @@ ScalarFunction ListConcatFun::GetFunction() { auto fun = ScalarFunction({}, LogicalType::LIST(LogicalType::ANY), ConcatFunction, BindConcatFunction, nullptr, ListConcatStats); fun.varargs = LogicalType::LIST(LogicalType::ANY); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } @@ -353,7 +361,7 @@ ScalarFunction ConcatFun::GetFunction() { ScalarFunction concat = ScalarFunction("concat", {LogicalType::ANY}, LogicalType::ANY, ConcatFunction, BindConcatFunction); concat.varargs = LogicalType::ANY; - concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + concat.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return concat; } diff --git a/src/duckdb/src/function/scalar/string/concat_ws.cpp b/src/duckdb/src/function/scalar/string/concat_ws.cpp index ebc1e8b3a..9b67878cd 100644 --- a/src/duckdb/src/function/scalar/string/concat_ws.cpp +++ b/src/duckdb/src/function/scalar/string/concat_ws.cpp @@ -142,7 +142,7 @@ ScalarFunction ConcatWsFun::GetFunction() { ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::ANY}, LogicalType::VARCHAR, ConcatWSFunction, BindConcatWSFunction); concat_ws.varargs = LogicalType::ANY; - concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + concat_ws.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return ScalarFunction(concat_ws); } diff --git 
a/src/duckdb/src/function/scalar/string/contains.cpp b/src/duckdb/src/function/scalar/string/contains.cpp index fb496b1fd..95c53ff01 100644 --- a/src/duckdb/src/function/scalar/string/contains.cpp +++ b/src/duckdb/src/function/scalar/string/contains.cpp @@ -121,7 +121,7 @@ idx_t FindStrInStr(const string_t &haystack_s, const string_t &needle_s) { ScalarFunction GetStringContains() { ScalarFunction string_fun("contains", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - string_fun.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + string_fun.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return string_fun; } diff --git a/src/duckdb/src/function/scalar/string/length.cpp b/src/duckdb/src/function/scalar/string/length.cpp index 66542af3c..4b5c50db7 100644 --- a/src/duckdb/src/function/scalar/string/length.cpp +++ b/src/duckdb/src/function/scalar/string/length.cpp @@ -63,7 +63,7 @@ unique_ptr LengthPropagateStats(ClientContext &context, Function D_ASSERT(child_stats.size() == 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::UnaryFunction; + expr.function.SetFunctionCallback(ScalarFunction::UnaryFunction); } return nullptr; } @@ -118,9 +118,9 @@ unique_ptr ArrayOrListLengthBind(ClientContext &context, ScalarFun const auto &arg_type = arguments[0]->return_type.id(); if (arg_type == LogicalTypeId::ARRAY) { - bound_function.function = ArrayLengthFunction; + bound_function.SetFunctionCallback(ArrayLengthFunction); } else if (arg_type == LogicalTypeId::LIST) { - bound_function.function = ListLengthFunction; + bound_function.SetFunctionCallback(ListLengthFunction); } else { // Unreachable throw BinderException("length can only be used on arrays or lists"); @@ -193,7 +193,7 @@ unique_ptr ArrayOrListLengthBinaryBind(ClientContext &context, Sca auto type = arguments[0]->return_type; if (type.id() == LogicalTypeId::ARRAY) { bound_function.arguments[0] = type; - bound_function.function = ArrayLengthBinaryFunction; + bound_function.SetFunctionCallback(ArrayLengthBinaryFunction); // If the input is an array, the dimensions are constant, so we can calculate them at bind time vector dimensions; @@ -210,7 +210,7 @@ unique_ptr ArrayOrListLengthBinaryBind(ClientContext &context, Sca return std::move(data); } else if (type.id() == LogicalTypeId::LIST) { - bound_function.function = ListLengthBinaryFunction; + bound_function.SetFunctionCallback(ListLengthBinaryFunction); bound_function.arguments[0] = type; return nullptr; } else { @@ -248,7 +248,7 @@ ScalarFunctionSet ArrayLengthFun::GetFunctions() { array_length.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::BIGINT, nullptr, ArrayOrListLengthBinaryBind)); for (auto &func : array_length.functions) { - BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return (array_length); } diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp index ba974f9d2..22e8b691a 100644 --- a/src/duckdb/src/function/scalar/string/like.cpp +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -498,7 +498,7 @@ unique_ptr ILikePropagateStats(ClientContext &context, FunctionS D_ASSERT(child_stats.size() >= 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = 
ScalarFunction::BinaryFunction; + expr.function.SetFunctionCallback(ScalarFunction::BinaryFunction); } return nullptr; } @@ -524,14 +524,14 @@ void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &resul ScalarFunction NotLikeFun::GetFunction() { ScalarFunction not_like("!~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegularLikeFunction, LikeBindFunction); - not_like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_like.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_like; } ScalarFunction GlobPatternFun::GetFunction() { ScalarFunction glob("~~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction); - glob.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + glob.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return glob; } @@ -539,7 +539,7 @@ ScalarFunction ILikeFun::GetFunction() { ScalarFunction ilike("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction, nullptr, nullptr, ILikePropagateStats); - ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike; } @@ -547,14 +547,14 @@ ScalarFunction NotILikeFun::GetFunction() { ScalarFunction not_ilike("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, ScalarFunction::BinaryFunction, nullptr, nullptr, ILikePropagateStats); - not_ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike; } ScalarFunction LikeFun::GetFunction() { ScalarFunction like("~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegularLikeFunction, LikeBindFunction); - like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + like.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return like; } @@ -562,14 +562,14 @@ ScalarFunction NotLikeEscapeFun::GetFunction() { ScalarFunction not_like_escape("not_like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - not_like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_like_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_like_escape; } ScalarFunction IlikeEscapeFun::GetFunction() { ScalarFunction ilike_escape("ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike_escape; } @@ -577,13 +577,13 @@ ScalarFunction NotIlikeEscapeFun::GetFunction() { ScalarFunction not_ilike_escape("not_ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - not_ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + not_ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike_escape; } ScalarFunction LikeEscapeFun::GetFunction() { ScalarFunction 
like_escape("like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LikeEscapeFunction); - like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; + like_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return like_escape; } diff --git a/src/duckdb/src/function/scalar/string/regexp.cpp b/src/duckdb/src/function/scalar/string/regexp.cpp index f91121a07..347fdfaa0 100644 --- a/src/duckdb/src/function/scalar/string/regexp.cpp +++ b/src/duckdb/src/function/scalar/string/regexp.cpp @@ -245,6 +245,11 @@ static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector // Regexp Extract Struct //===--------------------------------------------------------------------===// static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { + // This function assumes a constant pre-compiled pattern stored in the local state. + // If a non-constant pattern reaches here it indicates a binder bug. Return a clean error instead of crashing. + if (!ExecuteFunctionState::GetFunctionState(state)) { + throw InternalException("REGEXP_EXTRACT struct variant executed without constant pattern state"); + } auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); const auto count = args.size(); @@ -346,32 +351,13 @@ static unique_ptr RegexExtractBind(ClientContext &context, ScalarF group_string = ""; } else if (group.type().id() == LogicalTypeId::LIST) { if (!constant_pattern) { - throw BinderException("%s with LIST requires a constant pattern", bound_function.name); - } - auto &list_children = ListValue::GetChildren(group); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of capture names", bound_function.name); + throw BinderException("%s with LIST of group names requires a constant pattern", bound_function.name); } - case_insensitive_set_t name_collision_set; + vector dummy_names; // not reused after bind child_list_t struct_children; - for (const auto &child : list_children) { - if (child.IsNull()) { - throw BinderException("NULL group name in %s", bound_function.name); - } - const auto group_name = child.ToString(); - if (name_collision_set.find(group_name) != name_collision_set.end()) { - throw BinderException("Duplicate group name \"%s\" in %s", group_name, bound_function.name); - } - name_collision_set.insert(group_name); - struct_children.emplace_back(make_pair(group_name, LogicalType::VARCHAR)); - } - bound_function.return_type = LogicalType::STRUCT(struct_children); - - duckdb_re2::StringPiece constant_piece(constant_string.c_str(), constant_string.size()); - RE2 constant_pattern(constant_piece, options); - if (size_t(constant_pattern.NumberOfCapturingGroups()) < list_children.size()) { - throw BinderException("Not enough group names in %s", bound_function.name); - } + regexp_util::ParseGroupNameList(context, bound_function.name, *arguments[2], constant_string, options, + constant_pattern, dummy_names, struct_children); + bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); } else { auto group_idx = group.GetValue(); if (group_idx < 0 || group_idx > 9) { @@ -409,7 +395,7 @@ ScalarFunctionSet RegexpMatchesFun::GetFunctions() { RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); for (auto &func : regexp_partial_match.functions) { - 
BaseScalarFunction::SetReturnsError(func); + func.SetFallible(); } return (regexp_partial_match); } @@ -467,6 +453,19 @@ ScalarFunctionSet RegexpExtractAllFun::GetFunctions() { LogicalType::LIST(LogicalType::VARCHAR), RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); + // Struct multi-match variant(s): pattern must be constant due to bind-time struct shape inference + regexp_extract_all.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, + LogicalType::LIST(LogicalType::VARCHAR), // temporary, replaced in bind + RegexpExtractAllStruct::Execute, RegexpExtractAllStruct::Bind, nullptr, nullptr, + RegexpExtractAllStruct::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING)); + regexp_extract_all.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR), LogicalType::VARCHAR}, + LogicalType::LIST(LogicalType::VARCHAR), // temporary, replaced in bind + RegexpExtractAllStruct::Execute, RegexpExtractAllStruct::Bind, nullptr, nullptr, + RegexpExtractAllStruct::InitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, + FunctionNullHandling::SPECIAL_HANDLING)); return (regexp_extract_all); } diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp index 144dcff03..151b7c599 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/function/scalar/string_functions.hpp" #include "re2/re2.h" +#include "re2/stringpiece.h" namespace duckdb { @@ -21,10 +22,19 @@ RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpr return nullptr; } +unique_ptr RegexpExtractAllStruct::InitLocalState(ExpressionState &state, + const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &info = bind_data->Cast(); + if (info.constant_pattern) { + return make_uniq(info, true); + } + return nullptr; +} + // Forwards startpos automatically bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, duckdb_re2::StringPiece *groups, int ngroups) { - D_ASSERT(pattern.ok()); D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); @@ -33,13 +43,8 @@ bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t } idx_t consumed = static_cast(groups[0].end() - (input.begin() + *startpos)); if (!consumed) { - // Empty match found, have to manually forward the input - // to avoid an infinite loop - // FIXME: support unicode characters - consumed++; - while (*startpos + consumed < input.length() && !IsCharacter(input[*startpos + consumed])) { - consumed++; - } + // Empty match: advance exactly one UTF-8 codepoint + consumed = regexp_util::AdvanceOneUTF8Basic(input, *startpos); } *startpos += consumed; return true; @@ -228,6 +233,136 @@ void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector & } } +static inline bool ExtractAllStruct(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re, idx_t &startpos, + duckdb_re2::StringPiece *groups, int provided_groups) { + D_ASSERT(re.ok()); + if (!re.Match(input, startpos, 
input.size(), re.UNANCHORED, groups, provided_groups + 1)) { + return false; + } + idx_t consumed = static_cast(groups[0].end() - (input.begin() + startpos)); + if (!consumed) { + consumed = regexp_util::AdvanceOneUTF8Basic(input, startpos); + } + startpos += consumed; + return true; +} + +static void ExtractStructAllSingleTuple(const string_t &string_val, duckdb_re2::RE2 &re, + vector &group_spans, + vector> &child_entries, Vector &result, idx_t row) { + const idx_t group_count = child_entries.size(); + auto list_entries = FlatVector::GetData(result); + idx_t current_list_size = ListVector::GetListSize(result); + list_entries[row].offset = current_list_size; + + auto input_piece = CreateStringPiece(string_val); + idx_t startpos = 0; + for (; ExtractAllStruct(input_piece, re, startpos, group_spans.data(), UnsafeNumericCast(group_count));) { + // Ensure capacity + if (current_list_size + 1 >= ListVector::GetListCapacity(result)) { + ListVector::Reserve(result, ListVector::GetListCapacity(result) * 2); + } + // Write each selected group + for (idx_t g = 0; g < group_count; g++) { + auto &child_vec = *child_entries[g]; + child_vec.SetVectorType(VectorType::FLAT_VECTOR); + auto cdata = FlatVector::GetData(child_vec); + auto &span = group_spans[g + 1]; + if (span.empty()) { + if (span.begin() == nullptr) { + // Unmatched optional group -> always NULL + FlatVector::Validity(child_vec).SetInvalid(current_list_size); + } + cdata[current_list_size] = string_t(string_val.GetData(), 0); + } else { + auto offset = span.begin() - string_val.GetData(); + cdata[current_list_size] = + string_t(string_val.GetData() + offset, UnsafeNumericCast(span.size())); + } + } + current_list_size++; + if (startpos > input_piece.size()) { + break; // empty match at end + } + } + list_entries[row].length = current_list_size - list_entries[row].offset; + ListVector::SetListSize(result, current_list_size); +} + +void RegexpExtractAllStruct::Execute(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + const auto &info = func_expr.bind_info->Cast(); + // Struct multi-match variant only supports constant pattern (enforced in Bind) + D_ASSERT(info.constant_pattern); + + // Expect arguments: string, pattern, list_of_group_names [, options] + auto &strings = args.data[0]; + + D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); + auto &struct_vector = ListVector::GetEntry(result); + D_ASSERT(struct_vector.GetType().id() == LogicalTypeId::STRUCT); + auto &child_entries = StructVector::GetEntries(struct_vector); + const idx_t group_count = child_entries.size(); + + // Reference original string buffer for zero-copy substring assignment + for (auto &child : child_entries) { + child->SetAuxiliary(strings.GetAuxiliary()); + child->SetVectorType(VectorType::FLAT_VECTOR); + } + + UnifiedVectorFormat strings_data; + strings.ToUnifiedFormat(args.size(), strings_data); + ListVector::Reserve(result, STANDARD_VECTOR_SIZE); + idx_t tuple_count = args.AllConstant() ? 
1 : args.size(); + + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + + auto &list_validity = FlatVector::Validity(result); + auto list_entries = FlatVector::GetData(result); + + vector group_spans(group_count + 1); + + for (idx_t row = 0; row < tuple_count; row++) { + auto sindex = strings_data.sel->get_index(row); + if (!strings_data.validity.RowIsValid(sindex)) { + list_entries[row].offset = ListVector::GetListSize(result); + list_entries[row].length = 0; + list_validity.SetInvalid(row); + continue; + } + auto &string_val = UnifiedVectorFormat::GetData(strings_data)[sindex]; + ExtractStructAllSingleTuple(string_val, lstate.constant_pattern, group_spans, child_entries, result, row); + } + if (args.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } +} + +unique_ptr RegexpExtractAllStruct::Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + // arguments: string, pattern, LIST group_names [, options] + if (arguments.size() < 3) { + throw BinderException("regexp_extract_all struct variant requires at least 3 arguments"); + } + duckdb_re2::RE2::Options options; + string constant_string; + bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); + if (!constant_pattern) { + throw BinderException("%s with LIST requires a constant pattern", bound_function.name); + } + if (arguments.size() >= 4) { + ParseRegexOptions(context, *arguments[3], options); + } + options.set_log_errors(false); + vector group_names; + child_list_t struct_children; + regexp_util::ParseGroupNameList(context, bound_function.name, *arguments[2], constant_string, options, true, + group_names, struct_children); + bound_function.SetReturnType(LogicalType::LIST(LogicalType::STRUCT(struct_children))); + return make_uniq(options, std::move(constant_string), constant_pattern, + std::move(group_names)); +} + unique_ptr RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { D_ASSERT(arguments.size() >= 2); diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp index 4e485195c..2bac42104 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp @@ -1,5 +1,7 @@ #include "duckdb/function/scalar/regexp.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "re2/re2.h" +#include "re2/stringpiece.h" namespace duckdb { @@ -78,6 +80,76 @@ void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &t ParseRegexOptions(StringValue::Get(options_str), target, global_replace); } +void ParseGroupNameList(ClientContext &context, const string &function_name, Expression &group_expr, + const string &pattern_string, RE2::Options &options, bool require_constant_pattern, + vector &out_names, child_list_t &out_struct_children) { + if (group_expr.HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!group_expr.IsFoldable()) { + throw InvalidInputException("Group specification field must be a constant list"); + } + Value list_val = ExpressionExecutor::EvaluateScalar(context, group_expr); + if (list_val.IsNull() || list_val.type().id() != LogicalTypeId::LIST) { + throw BinderException("Group specification must be a non-NULL LIST"); + } + auto &children = ListValue::GetChildren(list_val); + if (children.empty()) { + throw BinderException("Group name list must be non-empty"); + } + 
case_insensitive_set_t name_set; + for (auto &child : children) { + if (child.IsNull()) { + throw BinderException("NULL group name in %s", function_name); + } + auto name = child.ToString(); + if (name_set.find(name) != name_set.end()) { + throw BinderException("Duplicate group name '%s' in %s", name, function_name); + } + name_set.insert(name); + out_names.push_back(name); + out_struct_children.emplace_back(make_pair(name, LogicalType::VARCHAR)); + } + if (require_constant_pattern) { + duckdb_re2::StringPiece const_piece(pattern_string.c_str(), pattern_string.size()); + RE2 constant_re(const_piece, options); + auto group_cnt = constant_re.NumberOfCapturingGroups(); + if (group_cnt == -1) { + throw BinderException("Pattern failed to parse: %s", constant_re.error()); + } + if ((idx_t)group_cnt < out_names.size()) { + throw BinderException("Not enough capturing groups (%d) for provided names (%llu)", group_cnt, + NumericCast(out_names.size())); + } + } +} + +// Advance exactly one UTF-8 codepoint starting at 'base'. Throws on an invalid lead byte or truncated sequence, which should be unreachable since RE2 matches whole codepoints. +// Does not do a full validation of UTF-8 sequence, assumes input is mostly valid UTF-8. +idx_t AdvanceOneUTF8Basic(const duckdb_re2::StringPiece &input, idx_t base) { + if (base >= input.length()) { + return 1; // Out of bounds, just advance one byte + } + unsigned char first = static_cast(input[base]); + idx_t char_len = 1; + if ((first & 0x80) == 0) { + char_len = 1; // ASCII + } else if ((first & 0xE0) == 0xC0) { + char_len = 2; + } else if ((first & 0xF0) == 0xE0) { + char_len = 3; + } else if ((first & 0xF8) == 0xF0) { + char_len = 4; + } else { + // This should be impossible since RE2 operates on codepoints + throw InternalException("Invalid UTF-8 lead byte in regexp_extract_all"); + } + if (base + char_len > input.length()) { + throw InternalException("Invalid UTF-8 sequence in regexp_extract_all"); + } + return char_len; +} + } // namespace regexp_util } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/string_split.cpp b/src/duckdb/src/function/scalar/string/string_split.cpp index 070886d8c..2438dfe81 100644 --- a/src/duckdb/src/function/scalar/string/string_split.cpp +++ b/src/duckdb/src/function/scalar/string/string_split.cpp @@ -181,7 +181,7 @@ ScalarFunction StringSplitFun::GetFunction() { auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); ScalarFunction string_split({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitFunction); - string_split.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + string_split.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return string_split; } diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp index b82e6871b..b2c7d3a5a 100644 --- a/src/duckdb/src/function/scalar/string/substring.cpp +++ b/src/duckdb/src/function/scalar/string/substring.cpp @@ -16,7 +16,6 @@ static const int64_t SUPPORTED_UPPER_BOUND = NumericLimits::Maximum(); static const int64_t SUPPORTED_LOWER_BOUND = -SUPPORTED_UPPER_BOUND - 1; static inline void AssertInSupportedRange(idx_t input_size, int64_t offset, int64_t length) { - if (input_size > (uint64_t)SUPPORTED_UPPER_BOUND) { throw OutOfRangeException("Substring input size is too large (> %d)", SUPPORTED_UPPER_BOUND); } @@ -309,7 +308,7 @@ unique_ptr SubstringPropagateStats(ClientContext &context, Funct // can only propagate stats if the children have stats // we only care about the stats of the first child (i.e.
 the string)
 	if (!StringStats::CanContainUnicode(child_stats[0])) {
-		expr.function.function = SubstringFunctionASCII;
+		expr.function.SetFunctionCallback(SubstringFunctionASCII);
 	}
 	return nullptr;
 }
diff --git a/src/duckdb/src/function/scalar/struct/remap_struct.cpp b/src/duckdb/src/function/scalar/struct/remap_struct.cpp
index 136a89165..e926a7bec 100644
--- a/src/duckdb/src/function/scalar/struct/remap_struct.cpp
+++ b/src/duckdb/src/function/scalar/struct/remap_struct.cpp
@@ -11,6 +11,10 @@ namespace duckdb {
 namespace {
+static bool IsRemappable(const LogicalType &type) {
+	return type.IsNested() && type.id() != LogicalTypeId::VARIANT;
+}
+
 struct RemapColumnInfo {
 	optional_idx index;
 	optional_idx default_index;
@@ -230,7 +234,7 @@ void RemapStruct(Vector &input, Vector &default_vector, Vector &result, idx_t re
 void RemapNested(Vector &input, Vector &default_vector, Vector &result, idx_t result_size,
                  const vector &remap_info) {
 	auto &source_type = input.GetType();
-	D_ASSERT(source_type.IsNested());
+	D_ASSERT(IsRemappable(source_type));
 	switch (source_type.id()) {
 	case LogicalTypeId::STRUCT:
 		return RemapStruct(input, default_vector, result, result_size, remap_info);
@@ -293,7 +297,7 @@ struct RemapIndex {
 		RemapIndex index;
 		index.index = idx;
 		index.type = type;
-		if (type.IsNested()) {
+		if (IsRemappable(type)) {
 			index.child_map = make_uniq>(GetMap(type));
 		}
 		return index;
@@ -344,8 +348,8 @@ struct RemapEntry {
 		auto &source_type = entry->second.type;
 		auto &target_type = target_entry->second.type;
-		bool source_is_nested = source_type.IsNested();
-		bool target_is_nested = target_type.IsNested();
+		bool source_is_nested = IsRemappable(source_type);
+		bool target_is_nested = IsRemappable(target_type);
 		RemapEntry remap;
 		remap.index = entry->second.index;
 		remap.target_type = target_entry->second.type;
@@ -387,7 +391,7 @@ struct RemapEntry {
 		remap.default_index = default_idx;
 		if (default_type.id() == LogicalTypeId::STRUCT) {
 			// nested remap - recurse
-			if (!target_type.IsNested()) {
+			if (!IsRemappable(target_type)) {
 				throw BinderException("Default value is a struct - target value should be a nested type, is '%s'",
 				                      target_type.ToString());
 			}
@@ -436,7 +440,7 @@ struct RemapEntry {
 		RemapColumnInfo info;
 		info.index = entry->second.index;
 		info.default_index = entry->second.default_index;
-		if (child_type.IsNested() && entry->second.child_remaps) {
+		if (IsRemappable(child_type) && entry->second.child_remaps) {
 			// type is nested and a mapping for it is given - recurse
 			info.child_remap_info = ConstructMap(child_type, *entry->second.child_remaps);
 		}
@@ -447,7 +451,7 @@ struct RemapEntry {
 	static vector ConstructMap(const LogicalType &type, const case_insensitive_map_t &remap_map) {
-		D_ASSERT(type.IsNested());
+		D_ASSERT(IsRemappable(type));
 		switch (type.id()) {
 		case LogicalTypeId::STRUCT: {
 			auto &target_children = StructType::GetChildTypes(type);
@@ -484,7 +488,7 @@ struct RemapEntry {
 			auto remap_entry = remap_map.find(entry->second);
 			D_ASSERT(remap_entry != remap_map.end());
 			// this entry is remapped - fetch the target type
-			if (child_type.IsNested() && remap_entry->second.child_remaps) {
+			if (IsRemappable(child_type) && remap_entry->second.child_remaps) {
 				// type is nested and a mapping for it is given - recurse
 				new_source_children.emplace_back(child_name,
 				                                 RemapCast(child_type, *remap_entry->second.child_remaps));
@@ -552,7 +556,7 @@ unique_ptr RemapStructBind(ClientContext &context, ScalarFunction
 			// remap target can be NULL
 			continue;
 		}
-		if (!arg->return_type.IsNested()) {
+		if (!IsRemappable(arg->return_type)) {
 			throw BinderException("Struct remap can only remap nested types, not '%s'", arg->return_type.ToString());
 		} else if (arg->return_type.id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(arg->return_type)) {
 			throw BinderException("Struct remap can only remap named structs");
@@ -569,7 +573,7 @@ unique_ptr RemapStructBind(ClientContext &context, ScalarFunction
 			throw BinderException("The defaults have to be either NULL or a named STRUCT, not an unnamed struct");
 		}
-		if ((from_type.IsNested() || to_type.IsNested()) && from_type.id() != to_type.id()) {
+		if ((IsRemappable(from_type) || IsRemappable(to_type)) && from_type.id() != to_type.id()) {
 			throw BinderException("Can't change source type (%s) to target type (%s), type conversion not allowed",
 			                      from_type.ToString(), to_type.ToString());
 		}
@@ -617,7 +621,7 @@ unique_ptr RemapStructBind(ClientContext &context, ScalarFunction
 	bound_function.arguments[1] = arguments[1]->return_type;
 	bound_function.arguments[2] = arguments[2]->return_type;
 	bound_function.arguments[3] = arguments[3]->return_type;
-	bound_function.return_type = arguments[1]->return_type;
+	bound_function.SetReturnType(arguments[1]->return_type);
 	return make_uniq(std::move(remap));
 }
@@ -628,7 +632,7 @@ ScalarFunction RemapStructFun::GetFunction() {
 	ScalarFunction remap("remap_struct",
 	                     {LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalTypeId::ANY},
 	                     LogicalTypeId::ANY, RemapStructFunction, RemapStructBind);
-	remap.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	remap.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
 	return remap;
 }
diff --git a/src/duckdb/src/function/scalar/struct/struct_concat.cpp b/src/duckdb/src/function/scalar/struct/struct_concat.cpp
index ccfe7d363..153319891 100644
--- a/src/duckdb/src/function/scalar/struct/struct_concat.cpp
+++ b/src/duckdb/src/function/scalar/struct/struct_concat.cpp
@@ -33,7 +33,6 @@ static void StructConcatFunction(DataChunk &args, ExpressionState &state, Vector
 static unique_ptr StructConcatBind(ClientContext &context, ScalarFunction &bound_function,
                                    vector> &arguments) {
-
 	// collect names and deconflict, construct return type
 	if (arguments.empty()) {
 		throw InvalidInputException("struct_concat: At least one argument is required");
@@ -80,7 +79,7 @@ static unique_ptr StructConcatBind(ClientContext &context, ScalarF
 		throw InvalidInputException("struct_concat: Cannot mix named and unnamed STRUCTs");
 	}
-	bound_function.return_type = LogicalType::STRUCT(combined_children);
+	bound_function.SetReturnType(LogicalType::STRUCT(combined_children));
 	return nullptr;
 }
@@ -108,7 +107,7 @@ ScalarFunction StructConcatFun::GetFunction() {
 	ScalarFunction fun("struct_concat", {}, LogicalTypeId::STRUCT, StructConcatFunction, StructConcatBind, nullptr,
 	                   StructConcatStats);
 	fun.varargs = LogicalType::ANY;
-	fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
 	return fun;
 }
diff --git a/src/duckdb/src/function/scalar/struct/struct_contains.cpp b/src/duckdb/src/function/scalar/struct/struct_contains.cpp
index 3f8b39aa9..db9fd3554 100644
--- a/src/duckdb/src/function/scalar/struct/struct_contains.cpp
+++ b/src/duckdb/src/function/scalar/struct/struct_contains.cpp
@@ -204,7 +204,7 @@ static unique_ptr StructContainsBind(ClientContext &context, Scala
 	if (child_type.id() == LogicalTypeId::SQLNULL) {
 		bound_function.arguments[0] = LogicalTypeId::UNKNOWN;
 		bound_function.arguments[1] = LogicalTypeId::UNKNOWN;
-		bound_function.return_type = LogicalType::SQLNULL;
+		bound_function.SetReturnType(LogicalType::SQLNULL);
 		return nullptr;
 	}
@@ -248,7 +248,7 @@ ScalarFunction StructContainsFun::GetFunction() {
 ScalarFunction StructPositionFun::GetFunction() {
 	ScalarFunction fun("struct_contains", {LogicalTypeId::STRUCT, LogicalType::ANY}, LogicalType::INTEGER,
 	                   StructSearchFunction, StructContainsBind);
-	fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
 	return fun;
 }
diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp
index 23c5419cd..5da4a265e 100644
--- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp
+++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp
@@ -83,7 +83,7 @@ static unique_ptr StructExtractBind(ClientContext &context, Scalar
 		throw BinderException("Could not find key \"%s\" in struct\n%s", key, message);
 	}
-	bound_function.return_type = std::move(return_type);
+	bound_function.SetReturnType(std::move(return_type));
 	return StructExtractAtFun::GetBindData(key_index);
 }
@@ -120,7 +120,7 @@ static unique_ptr StructExtractBindInternal(ClientContext &context
 		throw BinderException("Key index %lld for struct_extract out of range - expected an index between 1 and %llu",
 		                      index, struct_children.size());
 	}
-	bound_function.return_type = struct_children[NumericCast(index - 1)].second;
+	bound_function.SetReturnType(struct_children[NumericCast(index - 1)].second);
 	return StructExtractAtFun::GetBindData(NumericCast(index - 1));
 }
diff --git a/src/duckdb/src/function/scalar/struct/struct_pack.cpp b/src/duckdb/src/function/scalar/struct/struct_pack.cpp
index dfbabcca0..ff7557fe1 100644
--- a/src/duckdb/src/function/scalar/struct/struct_pack.cpp
+++ b/src/duckdb/src/function/scalar/struct/struct_pack.cpp
@@ -56,8 +56,8 @@ static unique_ptr StructPackBind(ClientContext &context, ScalarFun
 	}
 	// this is more for completeness reasons
-	bound_function.return_type = LogicalType::STRUCT(struct_children);
-	return make_uniq(bound_function.return_type);
+	bound_function.SetReturnType(LogicalType::STRUCT(struct_children));
+	return make_uniq(bound_function.GetReturnType());
 }
 static unique_ptr StructPackStats(ClientContext &context, FunctionStatisticsInput &input) {
@@ -75,9 +75,9 @@ static ScalarFunction GetStructPackFunction() {
 	ScalarFunction fun(IS_STRUCT_PACK ? "struct_pack" : "row", {}, LogicalTypeId::STRUCT, StructPackFunction,
 	                   StructPackBind, nullptr, StructPackStats);
 	fun.varargs = LogicalType::ANY;
-	fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-	fun.serialize = VariableReturnBindData::Serialize;
-	fun.deserialize = VariableReturnBindData::Deserialize;
+	fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
+	fun.SetSerializeCallback(VariableReturnBindData::Serialize);
+	fun.SetDeserializeCallback(VariableReturnBindData::Deserialize);
 	return fun;
 }
diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp
index e6e8e22c0..84c51673b 100644
--- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp
+++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp
@@ -85,7 +85,7 @@ void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &
 	auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast();
 	local_state.allocator.Reset();
-	D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr));
+	D_ASSERT(bind_data.state_size == bind_data.aggr.GetStateSizeCallback()(bind_data.aggr));
 	D_ASSERT(input.data.size() == 1);
 	D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE);
 	auto aligned_state_size = AlignValue(bind_data.state_size);
@@ -105,13 +105,13 @@ void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &
 		} else {
 			// create a dummy state because finalize does not understand NULLs in its input
 			// we put the NULL back in explicitly below
-			bind_data.aggr.initialize(bind_data.aggr, data_ptr_cast(target_ptr));
+			bind_data.aggr.GetStateInitCallback()(bind_data.aggr, data_ptr_cast(target_ptr));
 		}
 		state_vec_ptr[i] = data_ptr_cast(target_ptr);
 	}
 	AggregateInputData aggr_input_data(nullptr, local_state.allocator);
-	bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0);
+	bind_data.aggr.GetStateFinalizeCallback()(local_state.addresses, aggr_input_data, result, input.size(), 0);
 	for (idx_t i = 0; i < input.size(); i++) {
 		auto state_idx = state_data.sel->get_index(i);
@@ -126,7 +126,7 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r
 	auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast();
 	local_state.allocator.Reset();
-	D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr));
+	D_ASSERT(bind_data.state_size == bind_data.aggr.GetStateSizeCallback()(bind_data.aggr));
 	D_ASSERT(input.data.size() == 2);
 	D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE);
@@ -176,7 +176,8 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r
 		memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size);
 		AggregateInputData aggr_input_data(nullptr, local_state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE);
-		bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1);
+		bind_data.aggr.GetStateCombineCallback()(local_state.state_vector0, local_state.state_vector1, aggr_input_data,
+		                                         1);
 		result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()),
 		                                              bind_data.state_size);
@@ -185,7 +186,6 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r
 unique_ptr BindAggregateState(ClientContext &context, ScalarFunction &bound_function,
                               vector> &arguments) {
-
 	// grab the aggregate type and bind the aggregate again
 	// the aggregate name and types are in the logical type of the aggregate state, make sure its sane
@@ -227,7 +227,7 @@ unique_ptr BindAggregateState(ClientContext &context, ScalarFuncti
 		                      error.Message());
 	}
 	auto bound_aggr = aggr.functions.GetFunctionByOffset(best_function.GetIndex());
-	if (bound_aggr.bind) {
+	if (bound_aggr.GetBindCallback()) {
 		// FIXME: this is really hacky
 		// but the aggregate state export needs a rework around how it handles more complex aggregates anyway
 		vector> args;
@@ -235,31 +235,32 @@ unique_ptr BindAggregateState(ClientContext &context, ScalarFuncti
 		for (auto &arg_type : state_type.bound_argument_types) {
 			args.push_back(make_uniq(Value(arg_type)));
 		}
-		auto bind_info = bound_aggr.bind(context, bound_aggr, args);
+		auto bind_info = bound_aggr.GetBindCallback()(context, bound_aggr, args);
 		if (bind_info) {
 			throw BinderException("Aggregate function with bind info not supported yet in aggregate state export");
 		}
 	}
-	if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) {
+	if (bound_aggr.GetReturnType() != state_type.return_type ||
+	    bound_aggr.arguments != state_type.bound_argument_types) {
 		throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name);
 	}
 	if (bound_function.name == "finalize") {
-		bound_function.return_type = bound_aggr.return_type;
+		bound_function.SetReturnType(bound_aggr.GetReturnType());
 	} else {
 		D_ASSERT(bound_function.name == "combine");
-		bound_function.return_type = arg_return_type;
+		bound_function.SetReturnType(arg_return_type);
 	}
-	return make_uniq(bound_aggr, bound_aggr.state_size(bound_aggr));
+	return make_uniq(bound_aggr, bound_aggr.GetStateSizeCallback()(bound_aggr));
 }
 void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
                              idx_t offset) {
 	D_ASSERT(offset == 0);
 	auto &bind_data = aggr_input_data.bind_data->Cast();
-	auto state_size = bind_data.aggregate->function.state_size(bind_data.aggregate->function);
+	auto state_size = bind_data.aggregate->function.GetStateSizeCallback()(bind_data.aggregate->function);
 	auto blob_ptr = FlatVector::GetData(result);
 	auto addresses_ptr = FlatVector::GetData(state);
 	for (idx_t row_idx = 0; row_idx < count; row_idx++) {
@@ -291,39 +292,40 @@ unique_ptr ExportStateScalarDeserialize(Deserializer &deserializer
 unique_ptr ExportAggregateFunction::Bind(unique_ptr child_aggregate) {
 	auto &bound_function = child_aggregate->function;
-	if (!bound_function.combine) {
+	if (!bound_function.HasStateCombineCallback()) {
 		throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name);
 	}
-	if (bound_function.bind) {
+	if (bound_function.HasBindCallback()) {
 		throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders");
 	}
-	if (bound_function.destructor) {
+	if (bound_function.HasStateDestructorCallback()) {
 		throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors");
 	}
 	// this should be required
-	D_ASSERT(bound_function.state_size);
-	D_ASSERT(bound_function.finalize);
+	D_ASSERT(bound_function.HasStateSizeCallback());
+	D_ASSERT(bound_function.HasStateFinalizeCallback());
-	D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID);
+	D_ASSERT(child_aggregate->function.GetReturnType().id() != LogicalTypeId::INVALID);
 #ifdef DEBUG
 	for (auto &arg_type : child_aggregate->function.arguments) {
 		D_ASSERT(arg_type.id() != LogicalTypeId::INVALID);
 	}
 #endif
 	auto export_bind_data = make_uniq(child_aggregate->Copy());
-	aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type,
+	aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.GetReturnType(),
 	                             child_aggregate->function.arguments);
 	auto return_type = LogicalType::AGGREGATE_STATE(std::move(state_type));
 	auto export_function =
 	    AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type,
-	                      bound_function.state_size, bound_function.initialize, bound_function.update,
-	                      bound_function.combine, ExportAggregateFinalize, bound_function.simple_update,
+	                      bound_function.GetStateSizeCallback(), bound_function.GetStateInitCallback(),
+	                      bound_function.GetStateUpdateCallback(), bound_function.GetStateCombineCallback(),
+	                      ExportAggregateFinalize, bound_function.GetStateSimpleUpdateCallback(),
 	                      /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr,
 	                      /* can't propagate statistics */ nullptr, nullptr);
-	export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-	export_function.serialize = ExportStateAggregateSerialize;
-	export_function.deserialize = ExportStateAggregateDeserialize;
+	export_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
+	export_function.SetSerializeCallback(ExportStateAggregateSerialize);
+	export_function.SetDeserializeCallback(ExportStateAggregateDeserialize);
 	return make_uniq(export_function, std::move(child_aggregate->children),
 	                 std::move(child_aggregate->filter), std::move(export_bind_data),
@@ -347,9 +349,9 @@ bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const
 ScalarFunction FinalizeFun::GetFunction() {
 	auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID,
 	                             AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState);
-	result.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-	result.serialize = ExportStateScalarSerialize;
-	result.deserialize = ExportStateScalarDeserialize;
+	result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
+	result.SetSerializeCallback(ExportStateScalarSerialize);
+	result.SetDeserializeCallback(ExportStateScalarDeserialize);
 	return result;
 }
@@ -357,9 +359,9 @@ ScalarFunction CombineFun::GetFunction() {
 	auto result = ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY},
 	                             LogicalTypeId::AGGREGATE_STATE, AggregateStateCombine, BindAggregateState, nullptr,
 	                             nullptr, InitCombineState);
-	result.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-	result.serialize = ExportStateScalarSerialize;
-	result.deserialize = ExportStateScalarDeserialize;
+	result.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING);
+	result.SetSerializeCallback(ExportStateScalarSerialize);
+	result.SetDeserializeCallback(ExportStateScalarDeserialize);
 	return result;
 }
diff --git a/src/duckdb/src/function/scalar/system/parse_log_message.cpp b/src/duckdb/src/function/scalar/system/parse_log_message.cpp
index d5e336165..1a81263c8 100644
--- a/src/duckdb/src/function/scalar/system/parse_log_message.cpp
+++ b/src/duckdb/src/function/scalar/system/parse_log_message.cpp
@@ -29,7 +29,6 @@ struct ParseLogMessageData : FunctionData {
 unique_ptr ParseLogMessageBind(ClientContext &context, ScalarFunction &bound_function,
                                vector> &arguments) {
-
 	if (arguments.size() != 2) {
 		throw BinderException("structured_log_schema: expects 1 argument", arguments[0]->alias);
 	}
@@ -53,9 +52,10 @@ unique_ptr ParseLogMessageBind(ClientContext &context, ScalarFunct if (!lookup->is_structured) { // Unstructured types we simply wrap in a struct with a single field called message child_list_t children = {{"message", LogicalType::VARCHAR}}; - bound_function.return_type = LogicalType::STRUCT(children); + bound_function.SetReturnType(LogicalType::STRUCT(children)); } else { - bound_function.return_type = lookup->type; + D_ASSERT(lookup->type.IsNested()); + bound_function.SetReturnType(lookup->type); } return make_uniq(*lookup); @@ -77,8 +77,10 @@ void ParseLogMessageFunction(DataChunk &args, ExpressionState &state, Vector &re } // namespace ScalarFunction ParseLogMessage::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::ANY, ParseLogMessageFunction, - ParseLogMessageBind, nullptr, nullptr, nullptr, LogicalType(LogicalTypeId::INVALID)); + auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::ANY, ParseLogMessageFunction, + ParseLogMessageBind, nullptr, nullptr, nullptr, LogicalType(LogicalTypeId::INVALID)); + fun.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + return fun; } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/system/write_log.cpp b/src/duckdb/src/function/scalar/system/write_log.cpp index fa67ec089..6fd7aba8d 100644 --- a/src/duckdb/src/function/scalar/system/write_log.cpp +++ b/src/duckdb/src/function/scalar/system/write_log.cpp @@ -65,7 +65,7 @@ unique_ptr WriteLogBind(ClientContext &context, ScalarFunction &bo auto result = make_uniq(); // Default return type - bound_function.return_type = LogicalType::VARCHAR; + bound_function.SetReturnType(LogicalType::VARCHAR); for (idx_t i = 1; i < arguments.size(); i++) { auto &arg = arguments[i]; @@ -100,7 +100,7 @@ unique_ptr WriteLogBind(ClientContext &context, ScalarFunction &bo } else if (arg->alias == "return_value") { result->return_type = arg->return_type; result->output_col = i; - bound_function.return_type = result->return_type; + bound_function.SetReturnType(result->return_type); } else { throw BinderException(StringUtil::Format("write_log: Unknown argument '%s'", arg->alias)); } diff --git a/src/duckdb/src/function/scalar/variant/variant_extract.cpp b/src/duckdb/src/function/scalar/variant/variant_extract.cpp index e0c10fa73..118175004 100644 --- a/src/duckdb/src/function/scalar/variant/variant_extract.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_extract.cpp @@ -3,6 +3,8 @@ #include "duckdb/function/scalar/regexp.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" namespace duckdb { @@ -12,6 +14,7 @@ struct BindData : public FunctionData { public: explicit BindData(const string &str); explicit BindData(uint32_t index); + BindData(const BindData &other) = default; public: unique_ptr Copy() const override; @@ -28,15 +31,15 @@ BindData::BindData(const string &str) : FunctionData() { component.key = str; } BindData::BindData(uint32_t index) : FunctionData() { + if (index == 0) { + throw BinderException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + } component.lookup_mode = VariantChildLookupMode::BY_INDEX; - component.index = index; + component.index = index - 1; } unique_ptr BindData::Copy() const { - if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { - return 
make_uniq(component.index); - } - return make_uniq(component.key); + return make_uniq(*this); } bool BindData::Equals(const FunctionData &other) const { @@ -65,6 +68,64 @@ static bool GetConstantArgument(ClientContext &context, Expression &expr, Value return false; } +optional_ptr FindShreddedStats(const BaseStatistics &shredded, + const VariantPathComponent &component) { + D_ASSERT(shredded.GetType().id() == LogicalTypeId::STRUCT); + D_ASSERT(StructType::GetChildTypes(shredded.GetType()).size() == 2); + + auto &typed_value_type = StructType::GetChildTypes(shredded.GetType())[1].second; + auto &typed_value_stats = StructStats::GetChildStats(shredded, 1); + switch (component.lookup_mode) { + case VariantChildLookupMode::BY_INDEX: { + if (typed_value_type.id() != LogicalTypeId::LIST) { + return nullptr; + } + auto &child_stats = ListStats::GetChildStats(typed_value_stats); + return child_stats; + } + case VariantChildLookupMode::BY_KEY: { + if (typed_value_type.id() != LogicalTypeId::STRUCT) { + return nullptr; + } + auto &object_fields = StructType::GetChildTypes(typed_value_type); + for (idx_t i = 0; i < object_fields.size(); i++) { + auto &object_field = object_fields[i]; + if (StringUtil::CIEquals(object_field.first, component.key)) { + return StructStats::GetChildStats(typed_value_stats, i); + } + } + return nullptr; + } + default: + throw InternalException("VariantChildLookupMode::%s not implemented for FindShreddedStats", + EnumUtil::ToString(component.lookup_mode)); + } +} + +static unique_ptr VariantExtractPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &bind_data = input.bind_data; + + auto &info = bind_data->Cast(); + auto &variant_stats = child_stats[0]; + const bool is_shredded = VariantStats::IsShredded(variant_stats); + if (!is_shredded) { + return nullptr; + } + auto &shredded_stats = VariantStats::GetShreddedStats(variant_stats); + auto found_stats = FindShreddedStats(shredded_stats, info.component); + if (!found_stats) { + return nullptr; + } + + auto &unshredded_stats = VariantStats::GetUnshreddedStats(variant_stats); + auto child_variant_stats = VariantStats::CreateShredded(found_stats->GetType()); + VariantStats::SetUnshreddedStats(child_variant_stats, unshredded_stats); + VariantStats::SetShreddedStats(child_variant_stats, *found_stats); + + return child_variant_stats.ToUnique(); +} + static unique_ptr VariantExtractBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { if (arguments.size() != 2) { @@ -142,22 +203,26 @@ static void VariantExtractFunction(DataChunk &input, ExpressionState &state, Vec } //! 
Look up the value_index of the child we're extracting - auto child_collection_result = - VariantUtils::FindChildValues(variant, component, optional_idx(), new_value_index_sel, nested_data, count); - if (!child_collection_result.Success()) { - if (child_collection_result.type == VariantChildDataCollectionResult::Type::INDEX_ZERO) { - throw InvalidInputException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, component, nullptr, new_value_index_sel, lookup_validity, nested_data, + count); + if (!lookup_validity.AllValid()) { + optional_idx index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + index = i; + break; + } } + D_ASSERT(index.IsValid()); switch (component.lookup_mode) { case VariantChildLookupMode::BY_INDEX: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto nested_index = index.GetIndex(); throw InvalidInputException("VARIANT(ARRAY(%d)) is missing index %d", nested_data[nested_index].child_count, component.index); } case VariantChildLookupMode::BY_KEY: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto nested_index = index.GetIndex(); auto row_index = nested_index; auto object_keys = VariantUtils::GetObjectKeys(variant, row_index, nested_data[nested_index]); throw InvalidInputException("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), @@ -225,10 +290,14 @@ ScalarFunctionSet VariantExtractFun::GetFunctions() { auto variant_type = LogicalType::VARIANT(); ScalarFunctionSet fun_set; - fun_set.AddFunction(ScalarFunction("variant_extract", {variant_type, LogicalType::VARCHAR}, variant_type, - VariantExtractFunction, VariantExtractBind)); - fun_set.AddFunction(ScalarFunction("variant_extract", {variant_type, LogicalType::UINTEGER}, variant_type, - VariantExtractFunction, VariantExtractBind)); + ScalarFunction variant_extract("variant_extract", {}, variant_type, VariantExtractFunction, VariantExtractBind, + nullptr, VariantExtractPropagateStats); + + variant_extract.arguments = {variant_type, LogicalType::VARCHAR}; + fun_set.AddFunction(variant_extract); + + variant_extract.arguments = {variant_type, LogicalType::UINTEGER}; + fun_set.AddFunction(variant_extract); return fun_set; } diff --git a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp new file mode 100644 index 000000000..07f2769ac --- /dev/null +++ b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp @@ -0,0 +1,268 @@ +#include "duckdb/function/scalar/regexp.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/execution/expression_executor.hpp" + +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" +#include "duckdb/function/variant/variant_normalize.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" + +namespace duckdb { + +VariantNormalizerState::VariantNormalizerState(idx_t result_row, VariantVectorData &source, + OrderedOwningStringMap &dictionary, + SelectionVector &keys_selvec) + : source(source), dictionary(dictionary), keys_selvec(keys_selvec), + keys_index_validity(source.keys_index_validity) { + auto keys_list_entry 
= source.keys_data[result_row]; + auto values_list_entry = source.values_data[result_row]; + auto children_list_entry = source.children_data[result_row]; + + keys_offset = keys_list_entry.offset; + children_offset = children_list_entry.offset; + + blob_data = data_ptr_cast(source.blob_data[result_row].GetDataWriteable()); + type_ids = source.type_ids_data + values_list_entry.offset; + byte_offsets = source.byte_offset_data + values_list_entry.offset; + values_indexes = source.values_index_data + children_list_entry.offset; + keys_indexes = source.keys_index_data + children_list_entry.offset; +} + +data_ptr_t VariantNormalizerState::GetDestination() { + return blob_data + blob_size; +} +uint32_t VariantNormalizerState::GetOrCreateIndex(const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; +} + +void VariantNormalizer::VisitNull(VariantNormalizerState &state) { + return; +} +void VariantNormalizer::VisitBoolean(bool val, VariantNormalizerState &state) { + return; +} + +void VariantNormalizer::VisitMetadata(VariantLogicalType type_id, VariantNormalizerState &state) { + state.type_ids[state.values_size] = static_cast(type_id); + state.byte_offsets[state.values_size] = state.blob_size; + state.values_size++; +} + +void VariantNormalizer::VisitFloat(float val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitDouble(double val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitUUID(hugeint_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitDate(date_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitInterval(interval_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTime(dtime_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimeNanos(dtime_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimeTZ(dtime_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimestampSec(timestamp_sec_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimestampMs(timestamp_ms_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimestamp(timestamp_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimestampNanos(timestamp_ns_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} +void VariantNormalizer::VisitTimestampTZ(timestamp_tz_t val, VariantNormalizerState &state) { + VisitInteger(val, state); +} + +void VariantNormalizer::VisitString(const string_t &str, VariantNormalizerState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size += length; +} +void VariantNormalizer::VisitBlob(const string_t &blob, VariantNormalizerState &state) { + return VisitString(blob, state); +} +void VariantNormalizer::VisitBignum(const string_t &bignum, VariantNormalizerState &state) { + return VisitString(bignum, state); +} +void VariantNormalizer::VisitGeometry(const 
string_t &geom, VariantNormalizerState &state) { + return VisitString(geom, state); +} +void VariantNormalizer::VisitBitstring(const string_t &bits, VariantNormalizerState &state) { + return VisitString(bits, state); +} + +void VariantNormalizer::VisitArray(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + idx_t result_children_idx = state.children_size; + state.blob_size += VarintEncode(result_children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto source_children_idx = nested_data.children_idx + i; + auto values_index = variant.GetValuesIndex(row, source_children_idx); + + //! Set the 'values_index' for the child, and set the 'keys_index' to NULL + state.values_indexes[result_children_idx] = state.values_size; + state.keys_index_validity.SetInvalid(state.children_offset + result_children_idx); + result_children_idx++; + + //! Visit the child value + VariantVisitor::Visit(variant, row, values_index, state); + } +} + +void VariantNormalizer::VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, VariantNormalizerState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (!nested_data.child_count) { + return; + } + uint32_t children_idx = state.children_size; + uint32_t keys_idx = state.keys_size; + state.blob_size += VarintEncode(children_idx, state.GetDestination()); + state.children_size += nested_data.child_count; + state.keys_size += nested_data.child_count; + + //! First iterate through all fields to populate the map of key -> field + map sorted_fields; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, nested_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + sorted_fields.emplace(key, i); + } + + //! Then visit the fields in sorted order + for (auto &entry : sorted_fields) { + auto source_children_idx = nested_data.children_idx + entry.second; + + //! Add the key of the field to the result + auto keys_index = variant.GetKeysIndex(row, source_children_idx); + auto &key = variant.GetKey(row, keys_index); + auto dict_index = state.GetOrCreateIndex(key); + state.keys_selvec.set_index(state.keys_offset + keys_idx, dict_index); + + //! Visit the child value + auto values_index = variant.GetValuesIndex(row, source_children_idx); + state.values_indexes[children_idx] = state.values_size; + state.keys_indexes[children_idx] = keys_idx; + children_idx++; + keys_idx++; + VariantVisitor::Visit(variant, row, values_index, state); + } +} + +void VariantNormalizer::VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantNormalizerState &state) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); +} + +void VariantNormalizer::Normalize(Vector &variant_vec, Vector &result, idx_t count) { + D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); + + //! Set up the access helper for the source VARIANT + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, source_format); + UnifiedVariantVectorData variant(source_format); + + //! 
Take the original sizes of the lists, the result will be similar size, never bigger + auto original_keys_size = ListVector::GetListSize(VariantVector::GetKeys(variant_vec)); + auto original_children_size = ListVector::GetListSize(VariantVector::GetChildren(variant_vec)); + auto original_values_size = ListVector::GetListSize(VariantVector::GetValues(variant_vec)); + + auto &keys = VariantVector::GetKeys(result); + auto &children = VariantVector::GetChildren(result); + auto &values = VariantVector::GetValues(result); + auto &data = VariantVector::GetData(result); + + ListVector::Reserve(keys, original_keys_size); + ListVector::SetListSize(keys, 0); + ListVector::Reserve(children, original_children_size); + ListVector::SetListSize(children, 0); + ListVector::Reserve(values, original_values_size); + ListVector::SetListSize(values, 0); + + //! Initialize the dictionary + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + + VariantVectorData variant_data(result); + SelectionVector keys_selvec; + keys_selvec.Initialize(original_keys_size); + + for (idx_t i = 0; i < count; i++) { + if (!variant.RowIsValid(i)) { + FlatVector::SetNull(result, i, true); + continue; + } + //! Allocate for the new data, use the same size as source + auto &blob_data = variant_data.blob_data[i]; + auto original_data = variant.GetData(i); + blob_data = StringVector::EmptyString(data, original_data.GetSize()); + + auto &keys_list_entry = variant_data.keys_data[i]; + keys_list_entry.offset = ListVector::GetListSize(keys); + + auto &children_list_entry = variant_data.children_data[i]; + children_list_entry.offset = ListVector::GetListSize(children); + + auto &values_list_entry = variant_data.values_data[i]; + values_list_entry.offset = ListVector::GetListSize(values); + + //! 
Visit the source to populate the result + VariantNormalizerState visitor_state(i, variant_data, dictionary, keys_selvec); + VariantVisitor::Visit(variant, i, 0, visitor_state); + + blob_data.SetSizeAndFinalize(visitor_state.blob_size, original_data.GetSize()); + keys_list_entry.length = visitor_state.keys_size; + children_list_entry.length = visitor_state.children_size; + values_list_entry.length = visitor_state.values_size; + + ListVector::SetListSize(keys, ListVector::GetListSize(keys) + visitor_state.keys_size); + ListVector::SetListSize(children, ListVector::GetListSize(children) + visitor_state.children_size); + ListVector::SetListSize(values, ListVector::GetListSize(values) + visitor_state.values_size); + } + + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, ListVector::GetListSize(keys)); + keys_entry.Slice(keys_selvec, ListVector::GetListSize(keys)); + + if (variant_vec.GetVectorType() == VectorType::CONSTANT_VECTOR) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } + result.Verify(count); +} + +static void VariantNormalizeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto count = input.size(); + + D_ASSERT(input.ColumnCount() == 1); + auto &variant_vec = input.data[0]; + VariantNormalizer::Normalize(variant_vec, result, count); +} + +ScalarFunction VariantNormalizeFun::GetFunction() { + auto variant_type = LogicalType::VARIANT(); + return ScalarFunction("variant_normalize", {variant_type}, variant_type, VariantNormalizeFunction); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp index d1767736e..19526a653 100644 --- a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp @@ -63,7 +63,7 @@ static void VariantTypeofFunction(DataChunk &input, ExpressionState &state, Vect ScalarFunction VariantTypeofFun::GetFunction() { auto variant_type = LogicalType::VARIANT(); auto res = ScalarFunction("variant_typeof", {variant_type}, LogicalType::VARCHAR, VariantTypeofFunction); - res.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + res.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return res; } diff --git a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index 44a370251..5160d9381 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -4,9 +4,23 @@ #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/types/variant_visitor.hpp" +#include "duckdb/function/variant/variant_value_convert.hpp" namespace duckdb { +PhysicalType VariantDecimalData::GetPhysicalType() const { + if (width > DecimalWidth::max) { + return PhysicalType::INT128; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT64; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT32; + } else { + return PhysicalType::INT16; + } +} + bool VariantUtils::IsNestedType(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { auto type_id = variant.GetTypeId(row, value_index); return type_id == VariantLogicalType::ARRAY || type_id == VariantLogicalType::OBJECT; @@ -19,10 +33,19 @@ VariantDecimalData VariantUtils::DecodeDecimalData(const UnifiedVariantVectorDat auto data = 
const_data_ptr_cast(variant.GetData(row).GetData()); auto ptr = data + byte_offset; - VariantDecimalData result; - result.width = VarintDecode(ptr); - result.scale = VarintDecode(ptr); - return result; + auto width = VarintDecode(ptr); + auto scale = VarintDecode(ptr); + auto value_ptr = ptr; + return VariantDecimalData(width, scale, value_ptr); +} + +string_t VariantUtils::DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { + auto byte_offset = variant.GetByteOffset(row, value_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto length = VarintDecode(ptr); + return string_t(reinterpret_cast(ptr), length); } VariantNestedData VariantUtils::DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, @@ -53,13 +76,11 @@ vector VariantUtils::GetObjectKeys(const UnifiedVariantVectorData &varia return object_keys; } -VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, optional_idx row, - SelectionVector &res, VariantNestedData *nested_data, - idx_t count) { - +void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count) { for (idx_t i = 0; i < count; i++) { - auto row_index = row.IsValid() ? row.GetIndex() : i; + auto row_index = sel ? sel->get_index(i) : i; auto &nested_data_entry = nested_data[i]; if (nested_data_entry.is_null) { @@ -67,13 +88,10 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { auto child_idx = component.index; - if (child_idx == 0) { - return VariantChildDataCollectionResult::IndexZero(); - } - child_idx--; if (child_idx >= nested_data_entry.child_count) { //! The list is too small to contain this index - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); + continue; } auto value_id = variant.GetValuesIndex(row_index, nested_data_entry.children_idx + child_idx); res[i] = static_cast(value_id); @@ -93,10 +111,9 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } } if (!found_child) { - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); } } - return VariantChildDataCollectionResult(); } vector VariantUtils::ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, @@ -146,133 +163,40 @@ VariantUtils::CollectNestedData(const UnifiedVariantVectorData &variant, Variant return VariantNestedDataCollectionResult(); } -Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx) { - if (!variant.RowIsValid(row)) { - return Value(LogicalTypeId::SQLNULL); - } +Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx) { + return VariantVisitor::Visit(variant, row, values_idx); +} - //! The 'values' data of the value we're currently converting - auto type_id = variant.GetTypeId(row, values_idx); - auto byte_offset = variant.GetByteOffset(row, values_idx); - - //! 
The blob data of the Variant, accessed by byte offset retrieved above ^ - auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); - - auto ptr = const_data_ptr_cast(blob_data + byte_offset); - switch (type_id) { - case VariantLogicalType::VARIANT_NULL: - return Value(LogicalType::SQLNULL); - case VariantLogicalType::BOOL_TRUE: - return Value::BOOLEAN(true); - case VariantLogicalType::BOOL_FALSE: - return Value::BOOLEAN(false); - case VariantLogicalType::INT8: - return Value::TINYINT(Load(ptr)); - case VariantLogicalType::INT16: - return Value::SMALLINT(Load(ptr)); - case VariantLogicalType::INT32: - return Value::INTEGER(Load(ptr)); - case VariantLogicalType::INT64: - return Value::BIGINT(Load(ptr)); - case VariantLogicalType::INT128: - return Value::HUGEINT(Load(ptr)); - case VariantLogicalType::UINT8: - return Value::UTINYINT(Load(ptr)); - case VariantLogicalType::UINT16: - return Value::USMALLINT(Load(ptr)); - case VariantLogicalType::UINT32: - return Value::UINTEGER(Load(ptr)); - case VariantLogicalType::UINT64: - return Value::UBIGINT(Load(ptr)); - case VariantLogicalType::UINT128: - return Value::UHUGEINT(Load(ptr)); - case VariantLogicalType::UUID: - return Value::UUID(Load(ptr)); - case VariantLogicalType::INTERVAL: - return Value::INTERVAL(Load(ptr)); - case VariantLogicalType::FLOAT: - return Value::FLOAT(Load(ptr)); - case VariantLogicalType::DOUBLE: - return Value::DOUBLE(Load(ptr)); - case VariantLogicalType::DATE: - return Value::DATE(date_t(Load(ptr))); - case VariantLogicalType::BLOB: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value::BLOB(const_data_ptr_cast(string_data), string_length); - } - case VariantLogicalType::VARCHAR: { - auto string_length = VarintDecode(ptr); - auto string_data = reinterpret_cast(ptr); - return Value(string_t(string_data, string_length)); - } - case VariantLogicalType::DECIMAL: { - auto width = NumericCast(VarintDecode(ptr)); - auto scale = NumericCast(VarintDecode(ptr)); - - if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else if (width > DecimalWidth::max) { - return Value::DECIMAL(Load(ptr), width, scale); - } else { - return Value::DECIMAL(Load(ptr), width, scale); - } - } - case VariantLogicalType::TIME_MICROS: - return Value::TIME(Load(ptr)); - case VariantLogicalType::TIME_MICROS_TZ: - return Value::TIMETZ(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS: - return Value::TIMESTAMP(Load(ptr)); - case VariantLogicalType::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(Load(ptr)); - case VariantLogicalType::TIMESTAMP_NANOS: - return Value::TIMESTAMPNS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MILIS: - return Value::TIMESTAMPMS(Load(ptr)); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - return Value::TIMESTAMPTZ(Load(ptr)); - case VariantLogicalType::ARRAY: { - auto count = VarintDecode(ptr); - vector array_items; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_index = variant.GetValuesIndex(row, child_index_start + i); - array_items.emplace_back(ConvertVariantToValue(variant, row, child_index)); - } +void VariantUtils::FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size) { + auto &keys = VariantVector::GetKeys(variant); + auto &keys_entry = ListVector::GetEntry(keys); + auto keys_entry_data = 
FlatVector::GetData(keys_entry); + + bool already_sorted = true; + + vector unsorted_to_sorted(dictionary.size()); + auto it = dictionary.begin(); + for (uint32_t sorted_idx = 0; sorted_idx < dictionary.size(); sorted_idx++) { + auto unsorted_idx = it->second; + if (unsorted_idx != sorted_idx) { + already_sorted = false; } - return Value::LIST(LogicalType::VARIANT(), std::move(array_items)); + unsorted_to_sorted[unsorted_idx] = sorted_idx; + D_ASSERT(sorted_idx < ListVector::GetListSize(keys)); + keys_entry_data[sorted_idx] = it->first; + auto size = static_cast(keys_entry_data[sorted_idx].GetSize()); + keys_entry_data[sorted_idx].SetSizeAndFinalize(size, size); + it++; } - case VariantLogicalType::OBJECT: { - auto count = VarintDecode(ptr); - child_list_t object_children; - if (count) { - auto child_index_start = VarintDecode(ptr); - for (idx_t i = 0; i < count; i++) { - auto child_value_idx = variant.GetValuesIndex(row, child_index_start + i); - auto val = ConvertVariantToValue(variant, row, child_value_idx); - - auto child_key_id = variant.GetKeysIndex(row, child_index_start + i); - auto &key = variant.GetKey(row, child_key_id); - - object_children.emplace_back(key.GetString(), std::move(val)); - } + + if (!already_sorted) { + //! Adjust the selection vector to point to the right dictionary index + for (idx_t i = 0; i < sel_size; i++) { + auto &entry = sel[i]; + auto sorted_idx = unsorted_to_sorted[entry]; + entry = sorted_idx; } - return Value::STRUCT(std::move(object_children)); - } - case VariantLogicalType::BITSTRING: { - auto string_length = VarintDecode(ptr); - return Value::BIT(ptr, string_length); - } - case VariantLogicalType::BIGNUM: { - auto string_length = VarintDecode(ptr); - return Value::BIGNUM(ptr, string_length); - } - default: - throw InternalException("VariantLogicalType(%d) not handled", static_cast(type_id)); } } diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp index d41f63b2c..f2f932768 100644 --- a/src/duckdb/src/function/table/arrow.cpp +++ b/src/duckdb/src/function/table/arrow.cpp @@ -200,11 +200,11 @@ void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunction if (global_state.CanRemoveFilterColumns()) { state.all_columns.Reset(); state.all_columns.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, data.lines_read - output_size); + ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns); output.ReferenceColumns(state.all_columns, global_state.projection_ids); } else { output.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, data.lines_read - output_size); + ArrowToDuckDB(state, data.arrow_table.GetColumns(), output); } output.Verify(); diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index e194852f0..511a272dc 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -257,7 +257,6 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, idx_t chunk_off static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, idx_t chunk_offset, ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, int64_t parent_offset) { - auto &array_info = arrow_type.GetTypeInfo(); auto array_size = array_info.FixedSize(); auto child_count = array_size * size; @@ -410,10 +409,10 @@ static void 
UUIDConversion(Vector &vector, const ArrowArray &array, idx_t chunk_ if (!validity_mask.RowIsValid(row)) { continue; } - tgt_ptr[row].lower = static_cast(BSwap(src_ptr[row].upper)); + tgt_ptr[row].lower = static_cast(BSwapIfLE(src_ptr[row].upper)); // flip Upper MSD - tgt_ptr[row].upper = - static_cast(static_cast(BSwap(src_ptr[row].lower)) ^ (static_cast(1) << 63)); + tgt_ptr[row].upper = static_cast(static_cast(BSwapIfLE(src_ptr[row].lower)) ^ + (static_cast(1) << 63)); } } @@ -695,7 +694,6 @@ template void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, int64_t nested_offset, uint64_t parent_offset, idx_t chunk_offset, ValidityMask &val_mask, DecimalBitWidth arrow_bit_width) { - switch (vector.GetType().InternalType()) { case PhysicalType::INT16: { auto tgt_ptr = FlatVector::GetData(vector); @@ -1184,7 +1182,6 @@ static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, i template static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - auto indices = reinterpret_cast(indices_p); for (idx_t row = 0; row < size; row++) { if (indices[row] > NumericLimits::Maximum()) { @@ -1370,8 +1367,7 @@ void ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(Vector &vector, Arro } void ArrowTableFunction::ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, - DataChunk &output, idx_t start, bool arrow_scan_is_projected, - idx_t rowid_column_index) { + DataChunk &output, bool arrow_scan_is_projected, idx_t rowid_column_index) { for (idx_t idx = 0; idx < output.ColumnCount(); idx++) { auto col_idx = scan_state.column_ids.empty() ? idx : scan_state.column_ids[idx]; diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp index 600e50f49..1ffd6e7ee 100644 --- a/src/duckdb/src/function/table/copy_csv.cpp +++ b/src/duckdb/src/function/table/copy_csv.cpp @@ -280,7 +280,31 @@ struct GlobalWriteCSVData : public GlobalFunctionData { return writer.FileSize(); } + unique_ptr GetLocalState(ClientContext &context, const idx_t flush_size) { + { + lock_guard guard(local_state_lock); + if (!local_states.empty()) { + auto result = std::move(local_states.back()); + local_states.pop_back(); + return result; + } + } + auto result = make_uniq(context, flush_size); + result->require_manual_flush = true; + return result; + } + + void StoreLocalState(unique_ptr lstate) { + lock_guard guard(local_state_lock); + lstate->Reset(); + local_states.push_back(std::move(lstate)); + } + CSVWriter writer; + +private: + mutex local_state_lock; + vector> local_states; }; static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { @@ -371,9 +395,7 @@ CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, b // Prepare Batch //===--------------------------------------------------------------------===// struct WriteCSVBatchData : public PreparedBatchData { - explicit WriteCSVBatchData(ClientContext &context, const idx_t flush_size) - : writer_local_state(make_uniq(context, flush_size)) { - writer_local_state->require_manual_flush = true; + explicit WriteCSVBatchData(unique_ptr writer_state) : writer_local_state(std::move(writer_state)) { } //! 
The thread-local buffer to write data into @@ -397,7 +419,8 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct auto &global_state = gstate.Cast(); // write CSV chunks to the batch data - auto batch = make_uniq(context, NextPowerOfTwo(collection->SizeInBytes())); + auto local_writer_state = global_state.GetLocalState(context, NextPowerOfTwo(collection->SizeInBytes())); + auto batch = make_uniq(std::move(local_writer_state)); for (auto &chunk : collection->Chunks()) { WriteCSVChunkInternal(global_state.writer, *batch->writer_local_state, cast_chunk, chunk, executor); } @@ -412,6 +435,7 @@ void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalF auto &csv_batch = batch.Cast(); auto &global_state = gstate.Cast(); global_state.writer.Flush(*csv_batch.writer_local_state); + global_state.StoreLocalState(std::move(csv_batch.writer_local_state)); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/function/table/direct_file_reader.cpp b/src/duckdb/src/function/table/direct_file_reader.cpp index 8aa6aba35..62bfdf7b2 100644 --- a/src/duckdb/src/function/table/direct_file_reader.cpp +++ b/src/duckdb/src/function/table/direct_file_reader.cpp @@ -1,6 +1,8 @@ #include "duckdb/function/table/direct_file_reader.hpp" + +#include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/function/table/read_file.hpp" -#include "duckdb/storage/caching_file_system.hpp" +#include "duckdb/storage/caching_file_system_wrapper.hpp" namespace duckdb { @@ -44,19 +46,20 @@ static inline void VERIFY(const string &filename, const string_t &content) { } } -void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &output) { +AsyncResult DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &output) { auto &state = global_state.Cast(); if (done || file_list_idx.GetIndex() >= state.file_list->GetTotalFileCount()) { - return; + return AsyncResult(SourceResultType::FINISHED); } auto files = state.file_list; - auto fs = CachingFileSystem::Get(context); - idx_t out_idx = 0; + + auto caching_fs = CachingFileSystemWrapper::Get(context); + const idx_t out_idx = 0; // We utilize projection pushdown here to only read the file content if the 'data' column is requested - unique_ptr file_handle = nullptr; + unique_ptr file_handle = nullptr; // Given the columns requested, do we even need to open the file? 
if (state.requires_file_open) { @@ -64,7 +67,15 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl if (FileSystem::IsRemoteFile(file.path)) { flags |= FileFlags::FILE_FLAGS_DIRECT_IO; } - file_handle = fs.OpenFile(QueryContext(context), file, flags); + file_handle = caching_fs.OpenFile(file, flags); + } else { + // At least verify that the file exist + // The globbing behavior in remote filesystems can lead to files being listed that do not actually exist + if (FileSystem::IsRemoteFile(file.path) && !caching_fs.FileExists(file.path)) { + output.SetCardinality(0); + done = true; + return SourceResultType::FINISHED; + } } for (idx_t col_idx = 0; col_idx < state.column_ids.size(); col_idx++) { @@ -81,31 +92,31 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl FlatVector::GetData(file_name_vector)[out_idx] = file_name_string; } break; case ReadFileBindData::FILE_CONTENT_COLUMN: { - auto file_size_raw = file_handle->GetFileSize(); - AssertMaxFileSize(file.path, file_size_raw); - auto file_size = UnsafeNumericCast(file_size_raw); - auto &file_content_vector = output.data[col_idx]; - auto content_string = StringVector::EmptyString(file_content_vector, file_size_raw); + const auto file_size = file_handle->GetFileSize(); + AssertMaxFileSize(file.path, file_size); - auto remaining_bytes = UnsafeNumericCast(file_size); + // Initialize write stream if not yet done + if (!state.stream) { + state.stream = make_uniq(BufferAllocator::Get(context), NextPowerOfTwo(file_size)); + } + state.stream->Rewind(); - // Read in batches of 100mb - constexpr auto MAX_READ_SIZE = 100LL * 1024 * 1024; + // Read in batches of 128mb + constexpr idx_t MAX_READ_SIZE = 128LL * 1024 * 1024; + auto remaining_bytes = file_handle->IsPipe() ? 
MAX_READ_SIZE : file_size; while (remaining_bytes > 0) { - const auto bytes_to_read = MinValue(remaining_bytes, MAX_READ_SIZE); - const auto content_string_ptr = content_string.GetDataWriteable() + (file_size - remaining_bytes); - - idx_t actually_read; - if (file_handle->IsRemoteFile()) { - // Remote file: caching read - data_ptr_t read_ptr; - actually_read = NumericCast(bytes_to_read); - auto buffer_handle = file_handle->Read(read_ptr, actually_read); - memcpy(content_string_ptr, read_ptr, actually_read); - } else { - // Local file: non-caching read - actually_read = NumericCast(file_handle->GetFileHandle().Read( - content_string_ptr, UnsafeNumericCast(bytes_to_read))); + const auto bytes_to_read = MinValue(remaining_bytes, MAX_READ_SIZE); + state.stream->GrowCapacity(bytes_to_read); + idx_t actually_read = NumericCast( + file_handle->Read(state.stream->GetData() + state.stream->GetPosition(), bytes_to_read)); + state.stream->SetPosition(state.stream->GetPosition() + actually_read); + AssertMaxFileSize(file.path, state.stream->GetPosition()); + + if (file_handle->IsPipe()) { + if (actually_read == 0) { + remaining_bytes = 0; + } + continue; } if (actually_read == 0) { @@ -113,16 +124,17 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl throw IOException("Failed to read file '%s' at offset %lu, unexpected EOF", file.path, file_size - remaining_bytes); } - remaining_bytes -= NumericCast(actually_read); + remaining_bytes -= actually_read; } - content_string.Finalize(); + auto &file_content_vector = output.data[col_idx]; + auto &content_string = FlatVector::GetData(file_content_vector)[out_idx]; + content_string = string_t(char_ptr_cast(state.stream->GetData()), + NumericCast(state.stream->GetPosition())); if (type == LogicalType::VARCHAR) { VERIFY(file.path, content_string); } - - FlatVector::GetData(file_content_vector)[out_idx] = content_string; } break; case ReadFileBindData::FILE_SIZE_COLUMN: { auto &file_size_vector = output.data[col_idx]; @@ -134,7 +146,7 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl // This can sometimes fail (e.g. 
httpfs file system cant always parse the last modified time // correctly) try { - auto timestamp_seconds = file_handle->GetLastModifiedTime(); + const auto timestamp_seconds = caching_fs.GetLastModifiedTime(*file_handle); FlatVector::GetData(last_modified_vector)[out_idx] = timestamp_tz_t(timestamp_seconds); } catch (std::exception &ex) { @@ -163,6 +175,7 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl } output.SetCardinality(1); done = true; + return AsyncResult(SourceResultType::HAVE_MORE_OUTPUT); }; void DirectFileReader::FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) { diff --git a/src/duckdb/src/function/table/read_duckdb.cpp b/src/duckdb/src/function/table/read_duckdb.cpp index c68f1c32e..4f8cfac66 100644 --- a/src/duckdb/src/function/table/read_duckdb.cpp +++ b/src/duckdb/src/function/table/read_duckdb.cpp @@ -87,8 +87,8 @@ class DuckDBReader : public BaseFileReader { public: bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; shared_ptr GetUnionData(idx_t file_idx) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) override; double GetProgressInFile(ClientContext &context) override; @@ -300,14 +300,38 @@ bool DuckDBReader::TryInitializeScan(ClientContext &context, GlobalTableFunction return true; } -void DuckDBReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, LocalTableFunctionState &lstate_p, - DataChunk &chunk) { +AsyncResult DuckDBReader::Scan(ClientContext &context, GlobalTableFunctionState &gstate_p, + LocalTableFunctionState &lstate_p, DataChunk &chunk) { chunk.Reset(); auto &lstate = lstate_p.Cast(); TableFunctionInput input(bind_data.get(), lstate.local_state, global_state); - scan_function.function(context, input, chunk); - if (chunk.size() == 0) { - finished = true; + + if (!scan_function.function) { + throw InternalException("DuckDBReader works only with simple table functions"); + } else { + input.async_result = AsyncResultType::IMPLICIT; + input.results_execution_mode = AsyncResultsExecutionMode::TASK_EXECUTOR; + scan_function.function(context, input, chunk); + + switch (input.async_result.GetResultType()) { + case AsyncResultType::BLOCKED: + return std::move(input.async_result); + case AsyncResultType::HAVE_MORE_OUTPUT: + return SourceResultType::HAVE_MORE_OUTPUT; + case AsyncResultType::IMPLICIT: + if (chunk.size() > 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + finished = true; + return SourceResultType::FINISHED; + case AsyncResultType::FINISHED: + finished = true; + return SourceResultType::FINISHED; + default: + throw InternalException("DuckDBReader call of scan_function.function returned unexpected return '%'", + EnumUtil::ToChars(input.async_result.GetResultType())); + } + throw InternalException("DuckDBReader hasn't handled a scan_function.function return"); } } diff --git a/src/duckdb/src/function/table/read_file.cpp b/src/duckdb/src/function/table/read_file.cpp index d0481cc23..d929e8074 100644 --- a/src/duckdb/src/function/table/read_file.cpp +++ b/src/duckdb/src/function/table/read_file.cpp @@ -10,10 +10,43 @@ namespace duckdb { +namespace { + 
//------------------------------------------------------------------------------ // DirectMultiFileInfo //------------------------------------------------------------------------------ +template +struct DirectMultiFileInfo : MultiFileReaderInterface { + static unique_ptr CreateInterface(ClientContext &context); + unique_ptr InitializeOptions(ClientContext &context, + optional_ptr info) override; + bool ParseCopyOption(ClientContext &context, const string &key, const vector &values, + BaseFileReaderOptions &options, vector &expected_names, + vector &expected_types) override; + bool ParseOption(ClientContext &context, const string &key, const Value &val, MultiFileOptions &file_options, + BaseFileReaderOptions &options) override; + unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, + unique_ptr options) override; + void BindReader(ClientContext &context, vector &return_types, vector &names, + MultiFileBindData &bind_data) override; + optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, + FileExpandResult expand_result) override; + unique_ptr InitializeGlobalState(ClientContext &context, MultiFileBindData &bind_data, + MultiFileGlobalState &global_state) override; + unique_ptr InitializeLocalState(ExecutionContext &, GlobalTableFunctionState &) override; + shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, + BaseUnionData &union_data, const MultiFileBindData &bind_data_p) override; + shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, + const OpenFileInfo &file, idx_t file_idx, + const MultiFileBindData &bind_data) override; + shared_ptr CreateReader(ClientContext &context, const OpenFileInfo &file, + BaseFileReaderOptions &options, + const MultiFileOptions &file_options) override; + unique_ptr GetCardinality(const MultiFileBindData &bind_data, idx_t file_count) override; + FileGlobInput GetGlobInput() override; +}; + template unique_ptr DirectMultiFileInfo::CreateInterface(ClientContext &context) { return make_uniq(); @@ -132,14 +165,45 @@ FileGlobInput DirectMultiFileInfo::GetGlobInput() { } //------------------------------------------------------------------------------ -// Register +// Operations //------------------------------------------------------------------------------ + +struct ReadBlobOperation { + static constexpr const char *NAME = "read_blob"; + static constexpr const char *FILE_TYPE = "blob"; + + static inline LogicalType TYPE() { + return LogicalType::BLOB; + } +}; + +struct ReadTextOperation { + static constexpr const char *NAME = "read_text"; + static constexpr const char *FILE_TYPE = "text"; + + static inline LogicalType TYPE() { + return LogicalType::VARCHAR; + } +}; + template static TableFunction GetFunction() { MultiFileFunction> table_function(OP::NAME); + // Erase extra multi file reader options + table_function.named_parameters.erase("filename"); + table_function.named_parameters.erase("hive_partitioning"); + table_function.named_parameters.erase("union_by_name"); + table_function.named_parameters.erase("hive_types"); + table_function.named_parameters.erase("hive_types_autocast"); return table_function; } +} // namespace + +//------------------------------------------------------------------------------ +// Register +//------------------------------------------------------------------------------ + void ReadBlobFunction::RegisterFunction(BuiltinFunctions &set) { auto scan_fun = GetFunction(); 
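// Context for the registration below: GetFunction<OP>() wires ReadBlobOperation /
// ReadTextOperation into a multi-file table function and then erases the generic
// multi-file named parameters (filename, hive_partitioning, union_by_name,
// hive_types, hive_types_autocast) that make no sense for read_blob/read_text.
// A hypothetical usage sketch through the C++ client API (the query text and the
// glob are illustrative only, not taken from this patch):
//
//   DuckDB db(nullptr);
//   Connection con(db);
//   // read_text returns one row per matched file: filename, content, size, last_modified
//   auto result = con.Query("SELECT filename, size FROM read_text('data/*.txt')");
//   result->Print();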
set.AddFunction(MultiFileReader::CreateFunctionSet(scan_fun)); diff --git a/src/duckdb/src/function/table/summary.cpp b/src/duckdb/src/function/table/summary.cpp index d6c4615e4..8c12148ca 100644 --- a/src/duckdb/src/function/table/summary.cpp +++ b/src/duckdb/src/function/table/summary.cpp @@ -9,7 +9,6 @@ namespace duckdb { static unique_ptr SummaryFunctionBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::VARCHAR); names.emplace_back("summary"); diff --git a/src/duckdb/src/function/table/system/duckdb_columns.cpp b/src/duckdb/src/function/table/system/duckdb_columns.cpp index fe958ea3f..ff14fdd73 100644 --- a/src/duckdb/src/function/table/system/duckdb_columns.cpp +++ b/src/duckdb/src/function/table/system/duckdb_columns.cpp @@ -196,7 +196,8 @@ unique_ptr ColumnHelper::Create(CatalogEntry &entry) { case CatalogType::VIEW_ENTRY: return make_uniq(entry.Cast()); default: - throw NotImplementedException("Unsupported catalog type for duckdb_columns"); + throw NotImplementedException({{"catalog_type", CatalogTypeToString(entry.type)}}, + "Unsupported catalog type for duckdb_columns"); } } diff --git a/src/duckdb/src/function/table/system/duckdb_connection_count.cpp b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp new file mode 100644 index 000000000..ce7857f3b --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp @@ -0,0 +1,45 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/connection_manager.hpp" + +namespace duckdb { + +struct DuckDBConnectionCountData : public GlobalTableFunctionState { + DuckDBConnectionCountData() : count(0), finished(false) { + } + idx_t count; + bool finished; +}; + +static unique_ptr DuckDBConnectionCountBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("count"); + return_types.emplace_back(LogicalType::UBIGINT); + return nullptr; +} + +unique_ptr DuckDBConnectionCountInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + auto &conn_manager = context.db->GetConnectionManager(); + result->count = conn_manager.GetConnectionCount(); + return std::move(result); +} + +void DuckDBConnectionCountFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + return; + } + output.SetValue(0, 0, Value::UBIGINT(data.count)); + output.SetCardinality(1); + data.finished = true; +} + +void DuckDBConnectionCountFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_connection_count", {}, DuckDBConnectionCountFunction, + DuckDBConnectionCountBind, DuckDBConnectionCountInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_extensions.cpp b/src/duckdb/src/function/table/system/duckdb_extensions.cpp index 6a528c111..f467a6401 100644 --- a/src/duckdb/src/function/table/system/duckdb_extensions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_extensions.cpp @@ -97,44 +97,46 @@ unique_ptr DuckDBExtensionsInit(ClientContext &context // Secondly we scan all installed extensions and their install info #ifndef WASM_LOADABLE_EXTENSIONS - auto ext_directory = ExtensionHelper::GetExtensionDirectoryPath(context); - fs.ListFiles(ext_directory, [&](const string &path, bool 
is_directory) { - if (!StringUtil::EndsWith(path, ".duckdb_extension")) { - return; - } - ExtensionInformation info; - info.name = fs.ExtractBaseName(path); - info.installed = true; - info.loaded = false; - info.file_path = fs.JoinPath(ext_directory, path); - - // Check the info file for its installation source - auto info_file_path = fs.JoinPath(ext_directory, path + ".info"); - - // Read the info file - auto extension_install_info = ExtensionInstallInfo::TryReadInfoFile(fs, info_file_path, info.name); - info.install_mode = extension_install_info->mode; - info.extension_version = extension_install_info->version; - if (extension_install_info->mode == ExtensionInstallMode::REPOSITORY) { - info.installed_from = ExtensionRepository::GetRepository(extension_install_info->repository_url); - } else { - info.installed_from = extension_install_info->full_path; - } + auto ext_directories = ExtensionHelper::GetExtensionDirectoryPath(context); + for (const auto &ext_directory : ext_directories) { + fs.ListFiles(ext_directory, [&](const string &path, bool is_directory) { + if (!StringUtil::EndsWith(path, ".duckdb_extension")) { + return; + } + ExtensionInformation info; + info.name = fs.ExtractBaseName(path); + info.installed = true; + info.loaded = false; + info.file_path = fs.JoinPath(ext_directory, path); + + // Check the info file for its installation source + auto info_file_path = fs.JoinPath(ext_directory, path + ".info"); + + // Read the info file + auto extension_install_info = ExtensionInstallInfo::TryReadInfoFile(fs, info_file_path, info.name); + info.install_mode = extension_install_info->mode; + info.extension_version = extension_install_info->version; + if (extension_install_info->mode == ExtensionInstallMode::REPOSITORY) { + info.installed_from = ExtensionRepository::GetRepository(extension_install_info->repository_url); + } else { + info.installed_from = extension_install_info->full_path; + } - auto entry = installed_extensions.find(info.name); - if (entry == installed_extensions.end()) { - installed_extensions[info.name] = std::move(info); - } else { - if (entry->second.install_mode != ExtensionInstallMode::STATICALLY_LINKED) { - entry->second.file_path = info.file_path; - entry->second.install_mode = info.install_mode; - entry->second.installed_from = info.installed_from; - entry->second.install_mode = info.install_mode; - entry->second.extension_version = info.extension_version; + auto entry = installed_extensions.find(info.name); + if (entry == installed_extensions.end()) { + installed_extensions[info.name] = std::move(info); + } else { + if (entry->second.install_mode != ExtensionInstallMode::STATICALLY_LINKED) { + entry->second.file_path = info.file_path; + entry->second.install_mode = info.install_mode; + entry->second.installed_from = info.installed_from; + entry->second.install_mode = info.install_mode; + entry->second.extension_version = info.extension_version; + } + entry->second.installed = true; } - entry->second.installed = true; - } - }); + }); + } #endif // Finally, we check the list of currently loaded extensions diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp index b0c7656fe..09ce83bcd 100644 --- a/src/duckdb/src/function/table/system/duckdb_functions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_functions.cpp @@ -15,14 +15,20 @@ #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/types.hpp" #include "duckdb/main/client_data.hpp" +#include 
"duckdb/parser/expression/window_expression.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/function/scalar_function.hpp" namespace duckdb { +constexpr const char *AggregateFunctionCatalogEntry::Name; struct DuckDBFunctionsData : public GlobalTableFunctionState { - DuckDBFunctionsData() : offset(0), offset_in_entry(0) { + DuckDBFunctionsData() : window_iterator(WindowExpression::WindowFunctions()), offset(0), offset_in_entry(0) { } vector> entries; + const WindowFunctionDefinition *window_iterator; idx_t offset; idx_t offset_in_entry; }; @@ -141,7 +147,7 @@ struct ScalarFunctionExtractor { } static Value GetReturnType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + return Value(entry.functions.GetFunctionByOffset(offset).GetReturnType().ToString()); } static vector GetParameters(ScalarFunctionCatalogEntry &entry, idx_t offset) { @@ -176,11 +182,84 @@ struct ScalarFunctionExtractor { } static Value IsVolatile(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE); + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).GetStability() == + FunctionStability::VOLATILE); } static Value ResultType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability); + return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).GetStability()); + } +}; + +namespace { + +struct WindowFunctionCatalogEntry : CatalogEntry { +public: + WindowFunctionCatalogEntry(const SchemaCatalogEntry &schema, const string &name, vector arguments, + LogicalType return_type) + : CatalogEntry(CatalogType::AGGREGATE_FUNCTION_ENTRY, name, 0), schema(schema), arguments(std::move(arguments)), + return_type(std::move(return_type)) { + internal = true; + } + +public: + const SchemaCatalogEntry &schema; + vector arguments; + LogicalType return_type; + vector descriptions; + string alias_of; +}; + +} // namespace + +struct WindowFunctionExtractor { + static idx_t FunctionCount(WindowFunctionCatalogEntry &entry) { + return 1; + } + + static Value GetFunctionType() { + //! 
FIXME: should be 'window' but requires adapting generation scripts + return Value("aggregate"); + } + + static Value GetReturnType(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(entry.return_type.ToString()); + } + + static vector GetParameters(WindowFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.arguments.size(); i++) { + results.emplace_back("col" + to_string(i)); + } + return results; + } + + static Value GetParameterTypes(WindowFunctionCatalogEntry &entry, idx_t offset) { + vector results; + for (idx_t i = 0; i < entry.arguments.size(); i++) { + results.emplace_back(entry.arguments[i].ToString()); + } + return Value::LIST(LogicalType::VARCHAR, std::move(results)); + } + + static vector GetParameterLogicalTypes(WindowFunctionCatalogEntry &entry, idx_t offset) { + return entry.arguments; + } + + static Value GetVarArgs(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value GetMacroDefinition(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value(); + } + + static Value IsVolatile(WindowFunctionCatalogEntry &entry, idx_t offset) { + return Value::BOOLEAN(false); + } + + static Value ResultType(WindowFunctionCatalogEntry &entry, idx_t offset) { + return FunctionStabilityToValue(FunctionStability::CONSISTENT); } }; @@ -194,7 +273,7 @@ struct AggregateFunctionExtractor { } static Value GetReturnType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); + return Value(entry.functions.GetFunctionByOffset(offset).GetReturnType().ToString()); } static vector GetParameters(AggregateFunctionCatalogEntry &entry, idx_t offset) { @@ -229,11 +308,12 @@ struct AggregateFunctionExtractor { } static Value IsVolatile(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE); + return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).GetStability() == + FunctionStability::VOLATILE); } static Value ResultType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability); + return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).GetStability()); } }; @@ -497,7 +577,7 @@ static vector ToValueVector(vector &string_vector) { } template -static Value GetParameterNames(FunctionEntry &entry, idx_t function_idx, FunctionDescription &function_description, +static Value GetParameterNames(CatalogEntry &entry, idx_t function_idx, FunctionDescription &function_description, Value ¶meter_types) { vector parameter_names; if (!function_description.parameter_names.empty()) { @@ -566,13 +646,13 @@ static optional_idx GetFunctionDescriptionIndex(vector &fun } template -bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { +bool ExtractFunctionData(CatalogEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) { auto &function = entry.Cast(); vector parameter_types_vector = OP::GetParameterLogicalTypes(function, function_idx); Value parameter_types_value = OP::GetParameterTypes(function, function_idx); - optional_idx description_idx = GetFunctionDescriptionIndex(entry.descriptions, parameter_types_vector); + optional_idx description_idx = GetFunctionDescriptionIndex(function.descriptions, parameter_types_vector); FunctionDescription function_description 
= - description_idx.IsValid() ? entry.descriptions[description_idx.GetIndex()] : FunctionDescription(); + description_idx.IsValid() ? function.descriptions[description_idx.GetIndex()] : FunctionDescription(); idx_t col = 0; @@ -601,10 +681,10 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou (function_description.description.empty()) ? Value() : Value(function_description.description)); // comment, LogicalType::VARCHAR - output.SetValue(col++, output_offset, entry.comment); + output.SetValue(col++, output_offset, function.comment); // tags, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR) - output.SetValue(col++, output_offset, Value::MAP(entry.tags)); + output.SetValue(col++, output_offset, Value::MAP(function.tags)); // return_type, LogicalType::VARCHAR output.SetValue(col++, output_offset, OP::GetReturnType(function, function_idx)); @@ -645,9 +725,75 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou return function_idx + 1 == OP::FunctionCount(function); } +void ExtractWindowFunctionData(ClientContext &context, const WindowFunctionDefinition *it, DataChunk &output, + idx_t output_offset) { + D_ASSERT(it && it->name != nullptr); + string name(it->name); + + auto &system_catalog = Catalog::GetSystemCatalog(DatabaseInstance::GetDatabase(context)); + string schema_name(DEFAULT_SCHEMA); + EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema_name); + auto &default_schema = system_catalog.GetSchema(context, schema_lookup); + + switch (it->expression_type) { + case ExpressionType::WINDOW_FILL: + case ExpressionType::WINDOW_LAST_VALUE: + case ExpressionType::WINDOW_FIRST_VALUE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::TEMPLATE("T")}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_NTH_VALUE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::TEMPLATE("T"), LogicalType::BIGINT}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_ROW_NUMBER: + case ExpressionType::WINDOW_RANK: + case ExpressionType::WINDOW_RANK_DENSE: { + WindowFunctionCatalogEntry function(default_schema, name, {}, LogicalType::BIGINT); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_NTILE: { + WindowFunctionCatalogEntry function(default_schema, name, {LogicalType::BIGINT}, LogicalType::BIGINT); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_PERCENT_RANK: + case ExpressionType::WINDOW_CUME_DIST: { + WindowFunctionCatalogEntry function(default_schema, name, {}, LogicalType::DOUBLE); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + case ExpressionType::WINDOW_LAG: + case ExpressionType::WINDOW_LEAD: { + WindowFunctionCatalogEntry function( + default_schema, name, {LogicalType::TEMPLATE("T"), LogicalType::BIGINT, LogicalType::TEMPLATE("T")}, + LogicalType::TEMPLATE("T")); + ExtractFunctionData(function, 0, output, output_offset); + break; + } + default: + throw InternalException("Window function '%s' not implemented", name); + } +} + +static bool Finished(const DuckDBFunctionsData &data) { + if (data.offset < data.entries.size()) { + return false; + } + if (data.window_iterator->name == nullptr) { + return true; + } + return false; +} + void DuckDBFunctionsFunction(ClientContext &context, 
TableFunctionInput &data_p, DataChunk &output) { auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { + if (Finished(data)) { // finished returning values return; } @@ -696,6 +842,11 @@ void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, } count++; } + while (data.window_iterator->name != nullptr && count < STANDARD_VECTOR_SIZE) { + ExtractWindowFunctionData(context, data.window_iterator, output, count); + count++; + data.window_iterator++; + } output.SetCardinality(count); } diff --git a/src/duckdb/src/function/table/system/duckdb_log.cpp b/src/duckdb/src/function/table/system/duckdb_log.cpp index f84cb405a..96c35853f 100644 --- a/src/duckdb/src/function/table/system/duckdb_log.cpp +++ b/src/duckdb/src/function/table/system/duckdb_log.cpp @@ -62,6 +62,9 @@ unique_ptr DuckDBLogBindReplace(ClientContext &context, TableFunctionB bool denormalized_table = false; auto denormalized_table_setting = input.named_parameters.find("denormalized_table"); if (denormalized_table_setting != input.named_parameters.end()) { + if (denormalized_table_setting->second.IsNull()) { + throw InvalidInputException("denormalized_table cannot be NULL"); + } denormalized_table = denormalized_table_setting->second.GetValue(); } diff --git a/src/duckdb/src/function/table/system/duckdb_secrets.cpp b/src/duckdb/src/function/table/system/duckdb_secrets.cpp index 6069344bf..ae7f3104a 100644 --- a/src/duckdb/src/function/table/system/duckdb_secrets.cpp +++ b/src/duckdb/src/function/table/system/duckdb_secrets.cpp @@ -37,6 +37,9 @@ static unique_ptr DuckDBSecretsBind(ClientContext &context, TableF auto entry = input.named_parameters.find("redact"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for redact"); + } if (BooleanValue::Get(entry->second)) { result->redact = SecretDisplayType::REDACTED; } else { diff --git a/src/duckdb/src/function/table/system/logging_utils.cpp b/src/duckdb/src/function/table/system/logging_utils.cpp index 0cc6021ae..8a4d689eb 100644 --- a/src/duckdb/src/function/table/system/logging_utils.cpp +++ b/src/duckdb/src/function/table/system/logging_utils.cpp @@ -23,6 +23,10 @@ class EnableLoggingBindData : public TableFunctionData { static void EnableLogging(ClientContext &context, TableFunctionInput &data, DataChunk &output) { auto bind_data = data.bind_data->Cast(); + DUCKDB_LOG_WARNING(context, "The logging settings have been changed so you may lose warnings printed in the CLI.\n" + "To continue printing warnings to the console, set storage='shell_log_storage'.\n" + "For more info see https://duckdb.org/docs/stable/operations_manual/logging/overview.") + auto &log_manager = context.db->GetLogManager(); // Apply the config generated from the input diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp index 5500c1c5d..20ed3260e 100644 --- a/src/duckdb/src/function/table/system/pragma_storage_info.cpp +++ b/src/duckdb/src/function/table/system/pragma_storage_info.cpp @@ -12,6 +12,7 @@ #include "duckdb/storage/data_table.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/planner/binder.hpp" +#include "duckdb/storage/table/column_data.hpp" #include @@ -88,7 +89,7 @@ static unique_ptr PragmaStorageInfoBind(ClientContext &context, Ta Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); auto &table_entry = Catalog::GetEntry(context, 
qname.catalog, qname.schema, qname.name); auto result = make_uniq(table_entry); - result->column_segments_info = table_entry.GetColumnSegmentInfo(); + result->column_segments_info = table_entry.GetColumnSegmentInfo(context); return std::move(result); } @@ -155,6 +156,7 @@ static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput } else { output.SetValue(col_idx++, count, Value()); } + count++; } output.SetCardinality(count); diff --git a/src/duckdb/src/function/table/system/pragma_table_sample.cpp b/src/duckdb/src/function/table/system/pragma_table_sample.cpp index ce083d92c..cf5a9ccfb 100644 --- a/src/duckdb/src/function/table/system/pragma_table_sample.cpp +++ b/src/duckdb/src/function/table/system/pragma_table_sample.cpp @@ -32,7 +32,6 @@ struct DuckDBTableSampleOperatorData : public GlobalTableFunctionState { static unique_ptr DuckDBTableSampleBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - // look up the table name in the catalog auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); diff --git a/src/duckdb/src/function/table/system/pragma_user_agent.cpp b/src/duckdb/src/function/table/system/pragma_user_agent.cpp index 3803f7195..6448422bf 100644 --- a/src/duckdb/src/function/table/system/pragma_user_agent.cpp +++ b/src/duckdb/src/function/table/system/pragma_user_agent.cpp @@ -13,7 +13,6 @@ struct PragmaUserAgentData : public GlobalTableFunctionState { static unique_ptr PragmaUserAgentBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - names.emplace_back("user_agent"); return_types.emplace_back(LogicalType::VARCHAR); diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index cd4ba3964..8460ce9a9 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -19,9 +19,10 @@ struct TestAllTypesData : public GlobalTableFunctionState { idx_t offset; }; -vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_large_bignum) { +vector TestAllTypesFun::GetTestTypes(const bool use_large_enum, const bool use_large_bignum) { vector result; - // scalar types/numerics + + // Numeric types. result.emplace_back(LogicalType::BOOLEAN, "bool"); result.emplace_back(LogicalType::TINYINT, "tinyint"); result.emplace_back(LogicalType::SMALLINT, "smallint"); @@ -33,24 +34,31 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_lar result.emplace_back(LogicalType::USMALLINT, "usmallint"); result.emplace_back(LogicalType::UINTEGER, "uint"); result.emplace_back(LogicalType::UBIGINT, "ubigint"); + + // BIGNUM. if (use_large_bignum) { string data; - idx_t total_data_size = Bignum::BIGNUM_HEADER_SIZE + Bignum::MAX_DATA_SIZE; + constexpr idx_t total_data_size = Bignum::BIGNUM_HEADER_SIZE + Bignum::MAX_DATA_SIZE; data.resize(total_data_size); - // Let's set our header + + // Let's set the max header. Bignum::SetHeader(&data[0], Bignum::MAX_DATA_SIZE, false); - // Set all our other bits + // Set all other max bits. memset(&data[Bignum::BIGNUM_HEADER_SIZE], 0xFF, Bignum::MAX_DATA_SIZE); auto max = Value::BIGNUM(data); - // Let's set our header + + // Let's set the min header. Bignum::SetHeader(&data[0], Bignum::MAX_DATA_SIZE, true); - // Set all our other bits + // Set all other min bits. 
memset(&data[Bignum::BIGNUM_HEADER_SIZE], 0x00, Bignum::MAX_DATA_SIZE); auto min = Value::BIGNUM(data); result.emplace_back(LogicalType::BIGNUM, "bignum", min, max); + } else { result.emplace_back(LogicalType::BIGNUM, "bignum"); } + + // Time-types. result.emplace_back(LogicalType::DATE, "date"); result.emplace_back(LogicalType::TIME, "time"); result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); @@ -59,15 +67,19 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_lar result.emplace_back(LogicalType::TIMESTAMP_NS, "timestamp_ns"); result.emplace_back(LogicalType::TIME_TZ, "time_tz"); result.emplace_back(LogicalType::TIMESTAMP_TZ, "timestamp_tz"); - result.emplace_back(LogicalType::FLOAT, "float"); - result.emplace_back(LogicalType::DOUBLE, "double"); + + // More complex numeric types. + result.emplace_back(LogicalType::FLOAT, "float", Value::FLOAT(std::numeric_limits::lowest()), + Value::FLOAT(std::numeric_limits::max())); + result.emplace_back(LogicalType::DOUBLE, "double", Value::DOUBLE(std::numeric_limits::lowest()), + Value::DOUBLE(std::numeric_limits::max())); result.emplace_back(LogicalType::DECIMAL(4, 1), "dec_4_1"); result.emplace_back(LogicalType::DECIMAL(9, 4), "dec_9_4"); result.emplace_back(LogicalType::DECIMAL(18, 6), "dec_18_6"); result.emplace_back(LogicalType::DECIMAL(38, 10), "dec38_10"); result.emplace_back(LogicalType::UUID, "uuid"); - // interval + // Interval. interval_t min_interval; min_interval.months = 0; min_interval.days = 0; @@ -79,14 +91,15 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_lar max_interval.micros = 999999999; result.emplace_back(LogicalType::INTERVAL, "interval", Value::INTERVAL(min_interval), Value::INTERVAL(max_interval)); - // strings/blobs/bitstrings + + // VARCHAR / BLOB / Bitstrings. result.emplace_back(LogicalType::VARCHAR, "varchar", Value("🦆🦆🦆🦆🦆🦆"), Value(string("goo\x00se", 6))); result.emplace_back(LogicalType::BLOB, "blob", Value::BLOB("thisisalongblob\\x00withnullbytes"), Value::BLOB("\\x00\\x00\\x00a")); result.emplace_back(LogicalType::BIT, "bit", Value::BIT("0010001001011100010101011010111"), Value::BIT("10101")); - // enums + // ENUMs. Vector small_enum(LogicalType::VARCHAR, 2); auto small_enum_ptr = FlatVector::GetData(small_enum); small_enum_ptr[0] = StringVector::AddStringOrBlob(small_enum, "DUCK_DUCK_ENUM"); @@ -116,7 +129,7 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_lar result.emplace_back(LogicalType::ENUM(large_enum, 2), "large_enum"); } - // arrays + // ARRAYs. 
auto int_list_type = LogicalType::LIST(LogicalType::INTEGER); auto empty_int_list = Value::LIST(LogicalType::INTEGER, vector()); auto int_list = @@ -319,10 +332,16 @@ static unique_ptr TestAllTypesBind(ClientContext &context, TableFu bool use_large_bignum = false; auto entry = input.named_parameters.find("use_large_enum"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for use_large_enum"); + } use_large_enum = BooleanValue::Get(entry->second); } entry = input.named_parameters.find("use_large_bignum"); if (entry != input.named_parameters.end()) { + if (entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for use_large_bignum"); + } use_large_bignum = BooleanValue::Get(entry->second); } result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum, use_large_bignum); diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp index 23dab8758..5c5c073be 100644 --- a/src/duckdb/src/function/table/system/test_vector_types.cpp +++ b/src/duckdb/src/function/table/system/test_vector_types.cpp @@ -277,6 +277,9 @@ static unique_ptr TestVectorTypesBind(ClientContext &context, Tabl } for (auto &entry : input.named_parameters) { if (entry.first == "all_flat") { + if (entry.second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for all_flat"); + } result->all_flat = BooleanValue::Get(entry.second); } else { throw InternalException("Unrecognized named parameter for test_vector_types"); diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index d10ec5d31..0a6a03507 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -18,6 +18,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { PragmaDatabaseSize::RegisterFunction(*this); PragmaUserAgent::RegisterFunction(*this); + DuckDBConnectionCountFun::RegisterFunction(*this); DuckDBApproxDatabaseCountFun::RegisterFunction(*this); DuckDBColumnsFun::RegisterFunction(*this); DuckDBConstraintsFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index 99a9bcf79..99788a1ae 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -24,6 +24,7 @@ #include "duckdb/main/client_data.hpp" #include "duckdb/common/algorithm.hpp" #include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include "duckdb/planner/filter/in_filter.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" @@ -54,6 +55,7 @@ struct IndexScanLocalState : public LocalTableFunctionState { TableScanState scan_state; //! The column IDs of the local storage scan. 
vector column_ids; + bool in_charge_of_final_stretch {false}; }; static StorageIndex TransformStorageIndex(const ColumnIndex &column_id) { @@ -114,7 +116,7 @@ class DuckIndexScanState : public TableScanGlobalState { public: DuckIndexScanState(ClientContext &context, const FunctionData *bind_data_p) : TableScanGlobalState(context, bind_data_p), next_batch_index(0), arena(Allocator::Get(context)), - row_ids(nullptr), row_id_count(0), finished(false) { + row_ids(nullptr), row_id_count(0), finished_first_phase(false), started_last_phase(false) { } //! The batch index of the next Sink. @@ -129,7 +131,8 @@ class DuckIndexScanState : public TableScanGlobalState { //! The column IDs of the to-be-scanned columns. vector column_ids; //! True, if no more row IDs must be scanned. - bool finished; + bool finished_first_phase; + bool started_last_phase; //! Synchronize changes to the global index scan state. mutex index_scan_lock; @@ -163,44 +166,81 @@ class DuckIndexScanState : public TableScanGlobalState { auto &storage = duck_table.GetStorage(); auto &l_state = data_p.local_state->Cast(); - idx_t scan_count = 0; - idx_t offset = 0; - - { - // Synchronize changes to the shared global state. - lock_guard l(index_scan_lock); - if (!finished) { - l_state.batch_index = next_batch_index; - next_batch_index++; - - offset = l_state.batch_index * STANDARD_VECTOR_SIZE; - auto remaining = row_id_count - offset; - scan_count = remaining < STANDARD_VECTOR_SIZE ? remaining : STANDARD_VECTOR_SIZE; - finished = remaining < STANDARD_VECTOR_SIZE ? true : false; + enum class ExecutionPhase { NONE = 0, STORAGE = 1, LOCAL_STORAGE = 2 }; + + // We might need to loop back, so while (true) + while (true) { + idx_t scan_count = 0; + idx_t offset = 0; + + // Phase selection + auto phase_to_be_performed = ExecutionPhase::NONE; + { + // Synchronize changes to the shared global state. + lock_guard l(index_scan_lock); + if (!finished_first_phase) { + l_state.batch_index = next_batch_index; + next_batch_index++; + + offset = l_state.batch_index * STANDARD_VECTOR_SIZE; + auto remaining = row_id_count - offset; + scan_count = remaining <= STANDARD_VECTOR_SIZE ? remaining : STANDARD_VECTOR_SIZE; + finished_first_phase = remaining <= STANDARD_VECTOR_SIZE ? 
true : false; + phase_to_be_performed = ExecutionPhase::STORAGE; + } else if (!started_last_phase) { + // First thread to get last phase, great, set l_state's in_charge_of_final_stretch, so same thread + // will be on again + started_last_phase = true; + l_state.in_charge_of_final_stretch = true; + phase_to_be_performed = ExecutionPhase::LOCAL_STORAGE; + } else if (l_state.in_charge_of_final_stretch) { + phase_to_be_performed = ExecutionPhase::LOCAL_STORAGE; + } } - } - if (scan_count != 0) { - auto row_id_data = reinterpret_cast(row_ids + offset); - Vector local_vector(LogicalType::ROW_TYPE, row_id_data); - - if (CanRemoveFilterColumns()) { - l_state.all_columns.Reset(); - storage.Fetch(tx, l_state.all_columns, column_ids, local_vector, scan_count, l_state.fetch_state); - output.ReferenceColumns(l_state.all_columns, projection_ids); - } else { - storage.Fetch(tx, output, column_ids, local_vector, scan_count, l_state.fetch_state); + switch (phase_to_be_performed) { + case ExecutionPhase::NONE: { + // No work to be picked up + return; + } + case ExecutionPhase::STORAGE: { + // Scan (in parallel) storage + auto row_id_data = reinterpret_cast(row_ids + offset); + Vector local_vector(LogicalType::ROW_TYPE, row_id_data); + + if (CanRemoveFilterColumns()) { + l_state.all_columns.Reset(); + storage.Fetch(tx, l_state.all_columns, column_ids, local_vector, scan_count, l_state.fetch_state); + output.ReferenceColumns(l_state.all_columns, projection_ids); + } else { + storage.Fetch(tx, output, column_ids, local_vector, scan_count, l_state.fetch_state); + } + if (output.size() == 0) { + if (data_p.results_execution_mode == AsyncResultsExecutionMode::TASK_EXECUTOR) { + // We can avoid looping, and just return as appropriate + data_p.async_result = AsyncResultType::HAVE_MORE_OUTPUT; + return; + } + + // output is empty, loop back, since there might be results to be picked up from LOCAL_STORAGE phase + continue; + } + return; + } + case ExecutionPhase::LOCAL_STORAGE: { + // Scan (sequentially, always same logical thread) local_storage + auto &local_storage = LocalStorage::Get(tx); + { + if (CanRemoveFilterColumns()) { + l_state.all_columns.Reset(); + local_storage.Scan(l_state.scan_state.local_state, column_ids, l_state.all_columns); + output.ReferenceColumns(l_state.all_columns, projection_ids); + } else { + local_storage.Scan(l_state.scan_state.local_state, column_ids, output); + } + } + return; } - } - - if (output.size() == 0) { - auto &local_storage = LocalStorage::Get(tx); - if (CanRemoveFilterColumns()) { - l_state.all_columns.Reset(); - local_storage.Scan(l_state.scan_state.local_state, column_ids, l_state.all_columns); - output.ReferenceColumns(l_state.all_columns, projection_ids); - } else { - local_storage.Scan(l_state.scan_state.local_state, column_ids, output); } } } @@ -249,6 +289,11 @@ class DuckTableScanState : public TableScanGlobalState { storage_ids.push_back(GetStorageIndex(bind_data.table, col)); } + if (bind_data.order_options) { + l_state->scan_state.table_state.reorderer = make_uniq(*bind_data.order_options); + l_state->scan_state.local_state.reorderer = make_uniq(*bind_data.order_options); + } + l_state->scan_state.Initialize(std::move(storage_ids), context.client, input.filters, input.sample_options); storage.NextParallelScan(context.client, state, l_state->scan_state); @@ -265,9 +310,6 @@ class DuckTableScanState : public TableScanGlobalState { l_state.scan_state.options.force_fetch_row = ClientConfig::GetConfig(context).force_fetch_row; do { - if (context.interrupted) { - throw 
InterruptException(); - } if (bind_data.is_create_index) { storage.CreateIndexScan(l_state.scan_state, output, TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); @@ -283,9 +325,23 @@ class DuckTableScanState : public TableScanGlobalState { } auto next = storage.NextParallelScan(context, state, l_state.scan_state); + if (data_p.results_execution_mode == AsyncResultsExecutionMode::TASK_EXECUTOR) { + // We can avoid looping, and just return as appropriate + if (!next) { + data_p.async_result = AsyncResultType::FINISHED; + } else { + data_p.async_result = AsyncResultType::HAVE_MORE_OUTPUT; + } + return; + } if (!next) { return; } + + // Before looping back, check if we are interrupted + if (context.interrupted) { + throw InterruptException(); + } } while (true); } @@ -329,6 +385,11 @@ static unique_ptr TableScanInitLocal(ExecutionContext & unique_ptr DuckTableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input, DataTable &storage, const TableScanBindData &bind_data) { auto g_state = make_uniq(context, input.bind_data.get()); + if (bind_data.order_options) { + g_state->state.scan_state.reorderer = make_uniq(*bind_data.order_options); + g_state->state.local_state.reorderer = make_uniq(*bind_data.order_options); + } + storage.InitializeParallelScan(context, g_state->state); if (!input.CanRemoveFilterColumns()) { return std::move(g_state); @@ -350,7 +411,8 @@ unique_ptr DuckTableScanInitGlobal(ClientContext &cont unique_ptr DuckIndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input, const TableScanBindData &bind_data, set &row_ids) { auto g_state = make_uniq(context, input.bind_data.get()); - g_state->finished = row_ids.empty() ? true : false; + g_state->finished_first_phase = row_ids.empty() ? true : false; + g_state->started_last_phase = false; if (!row_ids.empty()) { auto row_id_ptr = g_state->arena.AllocateAligned(row_ids.size() * sizeof(row_t)); @@ -405,6 +467,9 @@ bool ExtractComparisonsAndInFilters(TableFilter &filter, vector()); return true; } + case TableFilterType::BLOOM_FILTER: { + return true; // We can't use it for finding cmp/in filters, but we can just ignore it + } case TableFilterType::CONJUNCTION_AND: { auto &conjunction_and = filter.Cast(); for (idx_t i = 0; i < conjunction_and.child_filters.size(); i++) { @@ -484,8 +549,8 @@ vector> ExtractFilterExpressions(const ColumnDefinition & return expressions; } -bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInput &input, TableFilterSet &filter_set, - idx_t max_count, set &row_ids) { +bool TryScanIndex(ART &art, IndexEntry &entry, const ColumnList &column_list, TableFunctionInitInput &input, + TableFilterSet &filter_set, idx_t max_count, set &row_ids) { // FIXME: No support for index scans on compound ARTs. // See note above on multi-filter support. 
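// Overview of the behaviour added further down in this function (a condensed sketch
// using the names from this patch, not the verbatim change): the caller now passes the
// owning IndexEntry, whose lock guards both the primary ART and the optional
// deleted_rows_in_use ART. Every extracted filter expression must initialize and scan
// successfully on every ART in arts_to_scan; any failure clears row_ids and the planner
// falls back to a regular table scan:
//
//   for (auto &art_ref : arts_to_scan) {
//       auto &art_to_scan = art_ref.get();
//       auto scan_state = art_to_scan.TryInitializeScan(*index_expr, *filter_expr);
//       if (!scan_state) {
//           return false;            // this filter cannot use the index scan
//       }
//       if (!art_to_scan.Scan(*scan_state, max_count, row_ids)) {
//           row_ids.clear();         // scan not possible within max_count: give up
//           return false;
//       }
//   }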
if (art.unbound_expressions.size() > 1) { @@ -553,17 +618,27 @@ bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInpu return false; } + lock_guard guard(entry.lock); + vector> arts_to_scan; + arts_to_scan.push_back(art); + if (entry.deleted_rows_in_use) { + arts_to_scan.push_back(entry.deleted_rows_in_use->Cast()); + } + auto expressions = ExtractFilterExpressions(col, filter->second, storage_index.GetIndex()); for (const auto &filter_expr : expressions) { - auto scan_state = art.TryInitializeScan(*index_expr, *filter_expr); - if (!scan_state) { - return false; - } + for (auto &art_ref : arts_to_scan) { + auto &art_to_scan = art_ref.get(); + auto scan_state = art_to_scan.TryInitializeScan(*index_expr, *filter_expr); + if (!scan_state) { + return false; + } - // Check if we can use an index scan, and already retrieve the matching row ids. - if (!art.Scan(*scan_state, max_count, row_ids)) { - row_ids.clear(); - return false; + // Check if we can use an index scan, and already retrieve the matching row ids. + if (!art_to_scan.Scan(*scan_state, max_count, row_ids)) { + row_ids.clear(); + return false; + } } } return true; @@ -592,9 +667,6 @@ unique_ptr TableScanInitGlobal(ClientContext &context, return DuckTableScanInitGlobal(context, input, storage, bind_data); } - // The checkpoint lock ensures that we do not checkpoint while scanning this table. - auto &transaction = DuckTransaction::Get(context, storage.db); - auto checkpoint_lock = transaction.SharedLockTable(*storage.GetDataTableInfo()); auto &info = storage.GetDataTableInfo(); auto &indexes = info->GetIndexes(); if (indexes.Empty()) { @@ -613,13 +685,14 @@ unique_ptr TableScanInitGlobal(ClientContext &context, set row_ids; info->BindIndexes(context, ART::TYPE_NAME); - info->GetIndexes().Scan([&](Index &index) { + info->GetIndexes().ScanEntries([&](IndexEntry &entry) { + auto &index = *entry.index; if (index.GetIndexType() != ART::TYPE_NAME) { return false; } D_ASSERT(index.IsBound()); auto &art = index.Cast(); - index_scan = TryScanIndex(art, column_list, input, filter_set, max_count, row_ids); + index_scan = TryScanIndex(art, entry, column_list, input, filter_set, max_count, row_ids); return index_scan; }); @@ -664,7 +737,6 @@ OperatorPartitionData TableScanGetPartitionData(ClientContext &context, TableFun vector TableScanGetPartitionStats(ClientContext &context, GetPartitionStatsInput &input) { auto &bind_data = input.bind_data->Cast(); - vector result; auto &duck_table = bind_data.table.Cast(); auto &storage = duck_table.GetStorage(); return storage.GetPartitionStats(context); @@ -740,6 +812,11 @@ vector TableScanGetRowIdColumns(ClientContext &context, optional_ptr order_options, optional_ptr bind_data_p) { + auto &bind_data = bind_data_p->Cast(); + bind_data.order_options = std::move(order_options); +} + TableFunction TableScanFunction::GetFunction() { TableFunction scan_function("seq_scan", {}, TableScanFunc); scan_function.init_local = TableScanInitLocal; @@ -763,6 +840,7 @@ TableFunction TableScanFunction::GetFunction() { scan_function.pushdown_expression = TableScanPushdownExpression; scan_function.get_virtual_columns = TableScanGetVirtualColumns; scan_function.get_row_id_columns = TableScanGetRowIdColumns; + scan_function.set_scan_order = SetScanOrder; return scan_function; } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index c3f5f0a6b..75c32b5cc 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ 
b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev383" +#define DUCKDB_PATCH_VERSION "0-dev4438" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev383" +#define DUCKDB_VERSION "v1.5.0-dev4438" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "07d170f87e" +#define DUCKDB_SOURCE_ID "e3080f5eeb" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" @@ -91,6 +91,9 @@ const char *DuckDB::ReleaseCodename() { if (StringUtil::StartsWith(DUCKDB_VERSION, "v1.4.")) { return "Andium"; } + if (StringUtil::StartsWith(DUCKDB_VERSION, "v1.5.")) { + return "Variegata"; + } // add new version names here // we should not get here, but let's not fail because of it because tags on forks can be whatever diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp index 310f75b58..a5ac3ba6d 100644 --- a/src/duckdb/src/function/table_function.cpp +++ b/src/duckdb/src/function/table_function.cpp @@ -14,11 +14,26 @@ PartitionStatistics::PartitionStatistics() : row_start(0), count(0), count_type( TableFunctionInfo::~TableFunctionInfo() { } -TableFunction::TableFunction(string name, vector arguments, table_function_t function, +TableFunction::TableFunction(string name, const vector &arguments, table_function_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments)), bind(bind), bind_replace(nullptr), - bind_operator(nullptr), init_global(init_global), init_local(init_local), function(function), + : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), + bind_operator(nullptr), init_global(init_global), init_local(init_local), function(function_), + in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), + cardinality(nullptr), pushdown_complex_filter(nullptr), pushdown_expression(nullptr), to_string(nullptr), + dynamic_to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), get_bind_info(nullptr), + type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), + get_partition_info(nullptr), get_partition_stats(nullptr), get_virtual_columns(nullptr), + get_row_id_columns(nullptr), set_scan_order(nullptr), serialize(nullptr), deserialize(nullptr), + projection_pushdown(false), filter_pushdown(false), filter_prune(false), sampling_pushdown(false), + late_materialization(false) { +} + +TableFunction::TableFunction(string name, const vector &arguments, std::nullptr_t function_, + table_function_bind_t bind, table_function_init_global_t init_global, + table_function_init_local_t init_local) + : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), + bind_operator(nullptr), init_global(init_global), init_local(init_local), function(nullptr), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), pushdown_complex_filter(nullptr), pushdown_expression(nullptr), to_string(nullptr), dynamic_to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), get_bind_info(nullptr), @@ -28,15 +43,44 @@ TableFunction::TableFunction(string name, 
vector arguments, table_f filter_pushdown(false), filter_prune(false), sampling_pushdown(false), late_materialization(false) { } -TableFunction::TableFunction(const vector &arguments, table_function_t function, +TableFunction::TableFunction(const vector &arguments, table_function_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) - : TableFunction(string(), arguments, function, bind, init_global, init_local) { + : TableFunction("", arguments, function_, bind, init_global, init_local) { +} + +TableFunction::TableFunction(const vector &arguments, std::nullptr_t function_, table_function_bind_t bind, + table_function_init_global_t init_global, table_function_init_local_t init_local) + : TableFunction("", arguments, function_, bind, init_global, init_local) { } TableFunction::TableFunction() : TableFunction("", {}, nullptr, nullptr, nullptr, nullptr) { } +bool TableFunction::operator==(const TableFunction &rhs) const { + return name == rhs.name && arguments == rhs.arguments && varargs == rhs.varargs && bind == rhs.bind && + bind_replace == rhs.bind_replace && bind_operator == rhs.bind_operator && init_global == rhs.init_global && + init_local == rhs.init_local && function == rhs.function && in_out_function == rhs.in_out_function && + in_out_function_final == rhs.in_out_function_final && statistics == rhs.statistics && + dependency == rhs.dependency && cardinality == rhs.cardinality && + pushdown_complex_filter == rhs.pushdown_complex_filter && pushdown_expression == rhs.pushdown_expression && + to_string == rhs.to_string && dynamic_to_string == rhs.dynamic_to_string && + table_scan_progress == rhs.table_scan_progress && get_partition_data == rhs.get_partition_data && + get_bind_info == rhs.get_bind_info && type_pushdown == rhs.type_pushdown && + get_multi_file_reader == rhs.get_multi_file_reader && supports_pushdown_type == rhs.supports_pushdown_type && + get_partition_info == rhs.get_partition_info && get_partition_stats == rhs.get_partition_stats && + get_virtual_columns == rhs.get_virtual_columns && get_row_id_columns == rhs.get_row_id_columns && + serialize == rhs.serialize && deserialize == rhs.deserialize && + verify_serialization == rhs.verify_serialization && projection_pushdown == rhs.projection_pushdown && + filter_pushdown == rhs.filter_pushdown && filter_prune == rhs.filter_prune && + sampling_pushdown == rhs.sampling_pushdown && late_materialization == rhs.late_materialization && + global_initialization == rhs.global_initialization; +} + +bool TableFunction::operator!=(const TableFunction &rhs) const { + return !(*this == rhs); +} + bool TableFunction::Equal(const TableFunction &rhs) const { // number of types if (this->arguments.size() != rhs.arguments.size()) { @@ -56,4 +100,22 @@ bool TableFunction::Equal(const TableFunction &rhs) const { return true; // they are equal } +bool ExtractSourceResultType(AsyncResultType in, SourceResultType &out) { + switch (in) { + case AsyncResultType::IMPLICIT: + case AsyncResultType::INVALID: + return false; + case AsyncResultType::HAVE_MORE_OUTPUT: + out = SourceResultType::HAVE_MORE_OUTPUT; + break; + case AsyncResultType::FINISHED: + out = SourceResultType::FINISHED; + break; + case AsyncResultType::BLOCKED: + out = SourceResultType::BLOCKED; + break; + } + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/function/udf_function.cpp b/src/duckdb/src/function/udf_function.cpp index 3c03dbbe3..55ba9385f 100644 --- a/src/duckdb/src/function/udf_function.cpp 
+++ b/src/duckdb/src/function/udf_function.cpp @@ -9,10 +9,9 @@ namespace duckdb { void UDFWrapper::RegisterFunction(string name, vector args, LogicalType ret_type, scalar_function_t udf_function, ClientContext &context, LogicalType varargs) { - ScalarFunction scalar_function(std::move(name), std::move(args), std::move(ret_type), std::move(udf_function)); scalar_function.varargs = std::move(varargs); - scalar_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + scalar_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); CreateScalarFunctionInfo info(scalar_function); info.schema = DEFAULT_SCHEMA; context.RegisterFunction(info); diff --git a/src/duckdb/src/function/variant/variant_shredding.cpp b/src/duckdb/src/function/variant/variant_shredding.cpp new file mode 100644 index 000000000..ee69da0e9 --- /dev/null +++ b/src/duckdb/src/function/variant/variant_shredding.cpp @@ -0,0 +1,342 @@ +#include "duckdb/function/variant/variant_shredding.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +namespace duckdb { + +static void WriteShreddedPrimitive(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count, idx_t type_size) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + + auto byte_offset = variant.GetByteOffset(row, value_index); + auto &data = variant.GetData(row); + auto value_ptr = data.GetData(); + auto result_offset = type_size * result_row; + memcpy(result_data + result_offset, value_ptr + byte_offset, type_size); + } +} + +template +static void WriteShreddedDecimal(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && variant.GetTypeId(row, value_index) == VariantLogicalType::DECIMAL); + + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, value_index); + D_ASSERT(decimal_data.width <= DecimalWidth::max); + auto result_offset = sizeof(T) * result_row; + memcpy(result_data + result_offset, decimal_data.value_ptr, sizeof(T)); + } +} + +static bool IsVariantStringType(VariantLogicalType type_id) { + switch (type_id) { + case VariantLogicalType::GEOMETRY: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: + case VariantLogicalType::BIGNUM: + return true; + default: + return false; + } +} + +static void WriteShreddedString(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && IsVariantStringType(variant.GetTypeId(row, value_index))); + + auto string_data = VariantUtils::DecodeStringData(variant, row, value_index); + result_data[result_row] = StringVector::AddStringOrBlob(result, string_data); + } +} + +static void 
WriteShreddedBoolean(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + auto type_id = variant.GetTypeId(row, value_index); + D_ASSERT(type_id == VariantLogicalType::BOOL_FALSE || type_id == VariantLogicalType::BOOL_TRUE); + + result_data[result_row] = type_id == VariantLogicalType::BOOL_TRUE; + } +} + +void VariantShredding::WriteTypedPrimitiveValues(UnifiedVariantVectorData &variant, Vector &result, + const SelectionVector &sel, const SelectionVector &value_index_sel, + const SelectionVector &result_sel, idx_t count) { + auto &type = result.GetType(); + D_ASSERT(!type.IsNested()); + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_NS: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::INTERVAL: + case LogicalTypeId::UUID: { + const auto physical_type = type.InternalType(); + WriteShreddedPrimitive(variant, result, sel, value_index_sel, result_sel, count, GetTypeIdSize(physical_type)); + break; + } + case LogicalTypeId::DECIMAL: { + const auto physical_type = type.InternalType(); + switch (physical_type) { + //! DECIMAL2 (doesn't exist in Parquet for some reason) + case PhysicalType::INT16: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL4 + case PhysicalType::INT32: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL8 + case PhysicalType::INT64: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL16 + case PhysicalType::INT128: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on column of type '%s'", type.ToString()); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::BIGNUM: + case LogicalTypeId::GEOMETRY: + case LogicalTypeId::BIT: + case LogicalTypeId::VARCHAR: { + WriteShreddedString(variant, result, sel, value_index_sel, result_sel, count); + break; + } + case LogicalTypeId::BOOLEAN: + WriteShreddedBoolean(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on type: %s", type.ToString()); + } +} + +void VariantShredding::WriteTypedObjectValues(UnifiedVariantVectorData &variant, Vector &result, + const SelectionVector &sel, const SelectionVector &value_index_sel, + const SelectionVector &result_sel, idx_t count) { + auto &type = result.GetType(); + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + + auto &validity = FlatVector::Validity(result); + (void)validity; + + //! 
Collect the nested data for the objects + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + //! When we're shredding an object, the top-level struct of it should always be valid + D_ASSERT(validity.RowIsValid(result_sel[i])); + auto value_index = value_index_sel[i]; + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::OBJECT); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + } + + auto &shredded_types = StructType::GetChildTypes(type); + auto &shredded_fields = StructVector::GetEntries(result); + D_ASSERT(shredded_types.size() == shredded_fields.size()); + + SelectionVector child_values_indexes; + SelectionVector child_row_sel; + SelectionVector child_result_sel; + child_values_indexes.Initialize(count); + child_row_sel.Initialize(count); + child_result_sel.Initialize(count); + + for (idx_t child_idx = 0; child_idx < shredded_types.size(); child_idx++) { + auto &child_vec = *shredded_fields[child_idx]; + D_ASSERT(child_vec.GetType() == shredded_types[child_idx].second); + + //! Prepare the path component to perform the lookup for + auto &key = shredded_types[child_idx].first; + VariantPathComponent path_component; + path_component.lookup_mode = VariantChildLookupMode::BY_KEY; + path_component.key = key; + + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, path_component, sel, child_values_indexes, lookup_validity, + nested_data.get(), count); + + if (!lookup_validity.AllValid()) { + auto &child_variant_vectors = StructVector::GetEntries(child_vec); + + //! For some of the rows the field is missing, adjust the selection vector to exclude these rows. + idx_t child_count = 0; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + //! The field is missing, set it to null + FlatVector::SetNull(*child_variant_vectors[0], result_sel[i], true); + if (child_variant_vectors.size() >= 2) { + FlatVector::SetNull(*child_variant_vectors[1], result_sel[i], true); + } + continue; + } + + child_row_sel[child_count] = sel[i]; + child_values_indexes[child_count] = child_values_indexes[i]; + child_result_sel[child_count] = result_sel[i]; + child_count++; + } + + if (child_count) { + //! 
If not all rows are missing this field, write the values for it + WriteVariantValues(variant, child_vec, child_row_sel, child_values_indexes, child_result_sel, + child_count); + } + } else { + WriteVariantValues(variant, child_vec, &sel, child_values_indexes, result_sel, count); + } + } +} + +void VariantShredding::WriteTypedArrayValues(UnifiedVariantVectorData &variant, Vector &result, + const SelectionVector &sel, const SelectionVector &value_index_sel, + const SelectionVector &result_sel, idx_t count) { + auto list_data = FlatVector::GetData(result); + + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + + idx_t total_offset = 0; + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto value_index = value_index_sel[i]; + auto result_row = result_sel[i]; + + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::ARRAY); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + + list_entry_t list_entry; + list_entry.length = nested_data[i].child_count; + list_entry.offset = total_offset; + list_data[result_row] = list_entry; + + total_offset += nested_data[i].child_count; + } + ListVector::Reserve(result, total_offset); + ListVector::SetListSize(result, total_offset); + + SelectionVector child_sel; + child_sel.Initialize(total_offset); + + SelectionVector child_value_index_sel; + child_value_index_sel.Initialize(total_offset); + + SelectionVector child_result_sel; + child_result_sel.Initialize(total_offset); + + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + + auto &array_data = nested_data[i]; + auto &entry = list_data[result_row]; + for (idx_t j = 0; j < entry.length; j++) { + auto offset = entry.offset + j; + child_sel[offset] = row; + child_value_index_sel[offset] = variant.GetValuesIndex(row, array_data.children_idx + j); + child_result_sel[offset] = NumericCast(offset); + } + } + + auto &child_vector = ListVector::GetEntry(result); + WriteVariantValues(variant, child_vector, child_sel, child_value_index_sel, child_result_sel, total_offset); +} + +void VariantShredding::WriteTypedValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto &type = result.GetType(); + + if (type.id() == LogicalTypeId::STRUCT) { + //! Shredded OBJECT + WriteTypedObjectValues(variant, result, sel, value_index_sel, result_sel, count); + } else if (type.id() == LogicalTypeId::LIST) { + //! Shredded ARRAY + WriteTypedArrayValues(variant, result, sel, value_index_sel, result_sel, count); + } else { + //! 
Primitive types + WriteTypedPrimitiveValues(variant, result, sel, value_index_sel, result_sel, count); + } +} + +VariantShreddingState::VariantShreddingState(const LogicalType &type, idx_t total_count) + : type(type), shredded_sel(total_count), values_index_sel(total_count), result_sel(total_count) { +} + +bool VariantShreddingState::ValueIsShredded(UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index) { + auto type_id = variant.GetTypeId(row, values_index); + if (!GetVariantTypes().count(type_id)) { + return false; + } + if (type_id == VariantLogicalType::DECIMAL) { + auto physical_type = type.InternalType(); + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + auto decimal_physical_type = decimal_data.GetPhysicalType(); + return physical_type == decimal_physical_type; + } + return true; +} + +void VariantShreddingState::SetShredded(uint32_t row, uint32_t values_index, uint32_t result_idx) { + shredded_sel[count] = row; + values_index_sel[count] = values_index; + result_sel[count] = result_idx; + count++; +} + +case_insensitive_string_set_t VariantShreddingState::ObjectFields() { + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + case_insensitive_string_set_t res; + auto &child_types = StructType::GetChildTypes(type); + for (auto &entry : child_types) { + auto &type = entry.first; + res.emplace(type.c_str(), static_cast(type.size())); + } + return res; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp index 95c8a5059..ceb906345 100644 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ b/src/duckdb/src/function/window/window_aggregate_function.cpp @@ -32,10 +32,10 @@ static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &w if (wexpr.aggregate && ClientConfig::GetConfig(context).enable_optimizer) { const auto &aggr = wexpr.aggregate; auto &arg_orders = wexpr.arg_orders; - if (aggr->distinct_dependent != AggregateDistinctDependent::DISTINCT_DEPENDENT) { + if (aggr->GetDistinctDependent() != AggregateDistinctDependent::DISTINCT_DEPENDENT) { wexpr.distinct = false; } - if (aggr->order_dependent != AggregateOrderDependent::ORDER_DEPENDENT) { + if (aggr->GetOrderDependent() != AggregateOrderDependent::ORDER_DEPENDENT) { arg_orders.clear(); } else { // If the argument order is prefix of the partition ordering, @@ -52,7 +52,6 @@ static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &w WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &client, WindowSharedExpressions &shared, WindowAggregationMode mode) : WindowExecutor(SimplifyWindowedAggregate(wexpr, client), shared), mode(mode) { - // Force naive for SEPARATE mode or for (currently!) 
unsupported functionality if (!ClientConfig::GetConfig(client).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { if (!WindowNaiveAggregator::CanAggregate(wexpr)) { @@ -111,7 +110,6 @@ class WindowAggregateExecutorLocalState : public WindowExecutorBoundsLocalState const WindowAggregator &aggregator) : WindowExecutorBoundsLocalState(context, gstate.Cast()), filter_executor(context.client) { - auto &gastate = gstate.Cast(); aggregator_state = aggregator.GetLocalState(context, *gastate.gsink); diff --git a/src/duckdb/src/function/window/window_aggregate_states.cpp b/src/duckdb/src/function/window/window_aggregate_states.cpp index 7c279cf5e..7db0d5d03 100644 --- a/src/duckdb/src/function/window/window_aggregate_states.cpp +++ b/src/duckdb/src/function/window/window_aggregate_states.cpp @@ -2,8 +2,9 @@ namespace duckdb { -WindowAggregateStates::WindowAggregateStates(const AggregateObject &aggr) - : aggr(aggr), state_size(aggr.function.state_size(aggr.function)), allocator(Allocator::DefaultAllocator()) { +WindowAggregateStates::WindowAggregateStates(ClientContext &client, const AggregateObject &aggr) + : client(client), aggr(aggr), state_size(aggr.function.GetStateSizeCallback()(aggr.function)), + allocator(Allocator::Get(client)) { } void WindowAggregateStates::Initialize(idx_t count) { @@ -18,21 +19,21 @@ void WindowAggregateStates::Initialize(idx_t count) { for (idx_t i = 0; i < count; ++i, state_ptr += state_size) { state_f_data[i] = state_ptr; - aggr.function.initialize(aggr.function, state_ptr); + aggr.function.GetStateInitCallback()(aggr.function, state_ptr); } // Prevent conversion of results to constants statef->SetVectorType(VectorType::FLAT_VECTOR); } -void WindowAggregateStates::Combine(WindowAggregateStates &target, AggregateCombineType combine_type) { +void WindowAggregateStates::Combine(WindowAggregateStates &target) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - aggr.function.combine(*statef, *target.statef, aggr_input_data, GetCount()); + aggr.function.GetStateCombineCallback()(*statef, *target.statef, aggr_input_data, GetCount()); } void WindowAggregateStates::Finalize(Vector &result) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(*statef, aggr_input_data, result, GetCount(), 0); + aggr.function.GetStateFinalizeCallback()(*statef, aggr_input_data, result, GetCount(), 0); } void WindowAggregateStates::Destroy() { @@ -41,8 +42,8 @@ void WindowAggregateStates::Destroy() { } AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - if (aggr.function.destructor) { - aggr.function.destructor(*statef, aggr_input_data, GetCount()); + if (aggr.function.HasStateDestructorCallback()) { + aggr.function.GetStateDestructorCallback()(*statef, aggr_input_data, GetCount()); } states.clear(); diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp index 3ac9c91c9..5e6f88ae5 100644 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ b/src/duckdb/src/function/window/window_aggregator.cpp @@ -10,9 +10,8 @@ namespace duckdb { // WindowAggregator //===--------------------------------------------------------------------===// WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr) - : wexpr(wexpr), aggr(wexpr), result_type(wexpr.return_type), state_size(aggr.function.state_size(aggr.function)), - exclude_mode(wexpr.exclude_clause) { - + : wexpr(wexpr), aggr(wexpr), 
result_type(wexpr.return_type), + state_size(aggr.function.GetStateSizeCallback()(aggr.function)), exclude_mode(wexpr.exclude_clause) { for (auto &child : wexpr.children) { arg_types.emplace_back(child->return_type); } @@ -30,9 +29,8 @@ WindowAggregator::~WindowAggregator() { WindowAggregatorGlobalState::WindowAggregatorGlobalState(ClientContext &client, const WindowAggregator &aggregator_p, idx_t group_count) - : client(client), allocator(Allocator::DefaultAllocator()), aggregator(aggregator_p), aggr(aggregator.wexpr), + : client(client), allocator(BufferAllocator::Get(client)), aggregator(aggregator_p), aggr(aggregator.wexpr), locals(0), finalized(0) { - if (aggr.filter) { // Start with all invalid and set the ones that pass filter_mask.Initialize(group_count, false); @@ -47,7 +45,7 @@ unique_ptr WindowAggregator::GetGlobalState(ClientContext &cont } WindowAggregatorLocalState::WindowAggregatorLocalState(ExecutionContext &context) - : allocator(Allocator::DefaultAllocator()) { + : allocator(BufferAllocator::Get(context.client)) { } void WindowAggregatorLocalState::Sink(ExecutionContext &context, WindowAggregatorGlobalState &gastate, diff --git a/src/duckdb/src/function/window/window_boundaries_state.cpp b/src/duckdb/src/function/window/window_boundaries_state.cpp index 84ae8abb2..84fbc7929 100644 --- a/src/duckdb/src/function/window/window_boundaries_state.cpp +++ b/src/duckdb/src/function/window/window_boundaries_state.cpp @@ -620,7 +620,6 @@ void WindowBoundariesState::PartitionEnd(DataChunk &bounds, idx_t row_idx, const void WindowBoundariesState::PeerBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, const ValidityMask &partition_mask, const ValidityMask &order_mask) { - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); // OVER() diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp index d35c90b4f..dfadfce1c 100644 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ b/src/duckdb/src/function/window/window_constant_aggregator.cpp @@ -30,12 +30,11 @@ class WindowConstantAggregatorGlobalState : public WindowAggregatorGlobalState { unique_ptr results; }; -WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(ClientContext &context, +WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(ClientContext &client, const WindowConstantAggregator &aggregator, idx_t group_count, const ValidityMask &partition_mask) - : WindowAggregatorGlobalState(context, aggregator, STANDARD_VECTOR_SIZE), statef(aggr) { - + : WindowAggregatorGlobalState(client, aggregator, STANDARD_VECTOR_SIZE), statef(client, aggr) { // Locate the partition boundaries if (partition_mask.AllValid()) { partition_offsets.emplace_back(0); @@ -104,8 +103,8 @@ class WindowConstantAggregatorLocalState : public WindowAggregatorLocalState { WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( ExecutionContext &context, const WindowConstantAggregatorGlobalState &gstate) - : WindowAggregatorLocalState(context), gstate(gstate), statep(Value::POINTER(0)), statef(gstate.statef.aggr), - partition(0) { + : WindowAggregatorLocalState(context), gstate(gstate), statep(Value::POINTER(0)), + statef(context.client, gstate.statef.aggr), partition(0) { matches.Initialize(); // Start the aggregates @@ -114,7 +113,7 @@ WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( statef.Initialize(partition_offsets.size() - 1); // Set up shared 
buffer - inputs.Initialize(Allocator::DefaultAllocator(), aggregator.arg_types); + inputs.Initialize(context.client, aggregator.arg_types); payload_chunk.InitializeEmpty(inputs.GetTypes()); gstate.locals++; @@ -201,7 +200,6 @@ BoundWindowExpression &WindowConstantAggregator::RebindAggregate(ClientContext & WindowConstantAggregator::WindowConstantAggregator(BoundWindowExpression &wexpr, WindowSharedExpressions &shared, ClientContext &context) : WindowAggregator(RebindAggregate(context, wexpr)) { - // We only need these values for Sink for (auto &child : wexpr.children) { child_idx.emplace_back(shared.RegisterSink(child)); @@ -239,7 +237,7 @@ void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChu payload_chunk.data[c].Reference(sink_chunk.data[child_idx[c]]); } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), statef.allocator); idx_t begin = 0; idx_t filter_idx = 0; auto partition_end = partition_offsets[partition + 1]; @@ -292,11 +290,13 @@ void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChu // Aggregate the filtered rows into a single state const auto count = inputs.size(); auto state = state_f_data[partition]; - if (aggr.function.simple_update) { - aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state, count); + if (aggr.function.HasStateSimpleUpdateCallback()) { + aggr.function.GetStateSimpleUpdateCallback()(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), + state, count); } else { state_p_data[0] = state_f_data[partition]; - aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); + aggr.function.GetStateUpdateCallback()(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, + count); } // Skip filtered rows too! 
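Editorial note on the recurring pattern above and in the window-function hunks that follow: direct reads of the aggregate function's callback members (state_size, initialize, update, simple_update, combine, finalize, destructor, window_init, window) are replaced by accessor calls such as GetStateSizeCallback(), GetStateInitCallback(), HasStateDestructorCallback(), and GetStateDestructorCallback(), and the window allocators are now obtained from the client context (Allocator::Get / BufferAllocator::Get) instead of Allocator::DefaultAllocator(). The sketch below is a minimal, self-contained illustration of the accessor pattern only; the type and member names are invented for the example and are not DuckDB's real classes.

#include <cstddef>
#include <cstdio>
#include <functional>

// Hypothetical stand-in for the real aggregate function type; names are invented for this sketch.
struct FakeAggregateFunction {
	// Old-style call sites read these members directly, e.g. fn.state_size(fn) or fn.initialize(fn, ptr).
	std::function<std::size_t(const FakeAggregateFunction &)> state_size;
	std::function<void(const FakeAggregateFunction &, void *)> initialize;
	std::function<void(void *, std::size_t)> destructor;

	// New-style call sites go through accessors, e.g. fn.GetStateSizeCallback()(fn).
	const std::function<std::size_t(const FakeAggregateFunction &)> &GetStateSizeCallback() const {
		return state_size;
	}
	const std::function<void(const FakeAggregateFunction &, void *)> &GetStateInitCallback() const {
		return initialize;
	}
	bool HasStateDestructorCallback() const {
		return static_cast<bool>(destructor);
	}
	const std::function<void(void *, std::size_t)> &GetStateDestructorCallback() const {
		return destructor;
	}
};

int main() {
	FakeAggregateFunction fn;
	fn.state_size = [](const FakeAggregateFunction &) -> std::size_t { return sizeof(int); };
	fn.initialize = [](const FakeAggregateFunction &, void *state) { *static_cast<int *>(state) = 0; };
	// No destructor registered, so HasStateDestructorCallback() returns false below.

	int state = 42;
	fn.GetStateInitCallback()(fn, &state);
	if (fn.HasStateDestructorCallback()) {
		fn.GetStateDestructorCallback()(&state, 1);
	}
	std::printf("state size: %zu, state after init: %d\n", fn.GetStateSizeCallback()(fn), state);
	return 0;
}

Call sites migrate mechanically under this pattern: fn.state_size(fn) becomes fn.GetStateSizeCallback()(fn), and presence checks such as if (fn.destructor) become if (fn.HasStateDestructorCallback()), which is exactly the shape of the substitutions in the window aggregator diffs below.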
diff --git a/src/duckdb/src/function/window/window_custom_aggregator.cpp b/src/duckdb/src/function/window/window_custom_aggregator.cpp index 993972ecd..69408634c 100644 --- a/src/duckdb/src/function/window/window_custom_aggregator.cpp +++ b/src/duckdb/src/function/window/window_custom_aggregator.cpp @@ -72,18 +72,18 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { WindowCustomAggregatorLocalState::WindowCustomAggregatorLocalState(ExecutionContext &context, const AggregateObject &aggr, const WindowExcludeMode exclude_mode) - : WindowAggregatorLocalState(context), aggr(aggr), state(aggr.function.state_size(aggr.function)), + : WindowAggregatorLocalState(context), aggr(aggr), state(aggr.function.GetStateSizeCallback()(aggr.function)), statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { // if we have a frame-by-frame method, share the single state - aggr.function.initialize(aggr.function, state.data()); + aggr.function.GetStateInitCallback()(aggr.function, state.data()); InitSubFrames(frames, exclude_mode); } WindowCustomAggregatorLocalState::~WindowCustomAggregatorLocalState() { - if (aggr.function.destructor) { + if (aggr.function.HasStateDestructorCallback()) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.destructor(statef, aggr_input_data, 1); + aggr.function.GetStateDestructorCallback()(statef, aggr_input_data, 1); } } @@ -115,13 +115,13 @@ void WindowCustomAggregator::Finalize(ExecutionContext &context, CollectionPtr c filter_mask.Pack(filter_packed, filter_mask.Capacity()); gcsink.glstate = GetLocalState(context, gcsink); - if (aggr.function.window_init) { + if (aggr.function.HasWindowInitCallback()) { auto &gcstate = gcsink.glstate->Cast(); WindowPartitionInput partition(context, inputs, count, child_idx, all_valids, filter_packed, stats, sink.interrupt_state); AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); - aggr.function.window_init(aggr_input_data, partition, gcstate.state.data()); + aggr.function.GetWindowInitCallback()(aggr_input_data, partition, gcstate.state.data()); } ++gcsink.finalized; @@ -153,7 +153,8 @@ void WindowCustomAggregator::Evaluate(ExecutionContext &context, const DataChunk EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { // Extract the range AggregateInputData aggr_input_data(aggr.GetFunctionData(), lcstate.allocator); - aggr.function.window(aggr_input_data, partition, gstate_p, lcstate.state.data(), frames, result, i); + aggr.function.GetWindowCallback()(aggr_input_data, partition, gstate_p, lcstate.state.data(), frames, result, + i); }); } diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp b/src/duckdb/src/function/window/window_distinct_aggregator.cpp index 063e25a80..315fb0974 100644 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ b/src/duckdb/src/function/window/window_distinct_aggregator.cpp @@ -73,12 +73,6 @@ class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState { //! Create a new local sort optional_ptr InitializeLocalSort(ExecutionContext &context) const; - ArenaAllocator &CreateTreeAllocator() const { - lock_guard tree_lock(lock); - tree_allocators.emplace_back(make_uniq(Allocator::DefaultAllocator())); - return *tree_allocators.back(); - } - bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate); //! The tree allocators. 
@@ -123,8 +117,7 @@ WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientC const WindowDistinctAggregator &aggregator, idx_t group_count) : WindowAggregatorGlobalState(client, aggregator, group_count), stage(WindowDistinctSortStage::INIT), - tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), levels_flat_native(aggr) { - + tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), levels_flat_native(client, aggr) { // 1: functionComputePrevIdcs(𝑖𝑛) // 2: sorted ← [] // We sort the aggregate arguments and use the partition index as a tie-breaker. @@ -136,7 +129,7 @@ WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientC vector orders; for (const auto &type : sort_types) { auto expr = make_uniq(type, orders.size()); - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr))); + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr)); sort_cols.emplace_back(sort_cols.size()); } @@ -199,8 +192,6 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { void Evaluate(ExecutionContext &context, const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx); - //! The thread-local allocator for building the tree - ArenaAllocator &tree_allocator; //! Thread-local sorting data optional_ptr local_sink; //! Finalize stage @@ -236,12 +227,12 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( ExecutionContext &context, const WindowDistinctAggregatorGlobalState &gdstate) - : WindowAggregatorLocalState(context), tree_allocator(gdstate.CreateTreeAllocator()), - update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), target_v(LogicalType::POINTER), gdstate(gdstate), - statef(gdstate.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { + : WindowAggregatorLocalState(context), update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), + target_v(LogicalType::POINTER), gdstate(gdstate), statef(context.client, gdstate.aggr), + statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { InitSubFrames(frames, gdstate.aggregator.exclude_mode); - sort_chunk.Initialize(Allocator::DefaultAllocator(), gdstate.sort_types); + sort_chunk.Initialize(context.client, gdstate.sort_types); gdstate.locals++; } @@ -297,7 +288,7 @@ void WindowDistinctAggregatorLocalState::Finalize(ExecutionContext &context, Win WindowAggregatorLocalState::Finalize(context, gastate, collection); //! Input data chunk, used for leaf segment aggregation - leaves.Initialize(Allocator::DefaultAllocator(), cursor->chunk.GetTypes()); + leaves.Initialize(context.client, cursor->chunk.GetTypes()); sel.Initialize(); } @@ -547,7 +538,7 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti auto &leaves = ldastate.leaves; auto &sel = ldastate.sel; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), ldastate.tree_allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), ldastate.allocator); //! 
The states to update auto &update_v = ldastate.update_v; @@ -583,11 +574,12 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti // Push the updates first so they propagate leaves.Reference(inputs); leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); + aggr.function.GetStateUpdateCallback()(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), + update_v, nupdate); nupdate = 0; // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); + aggr.function.GetStateCombineCallback()(source_v, target_v, aggr_input_data, ncombine); ncombine = 0; // Move the update into range. @@ -613,11 +605,12 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti // Push the updates first so they propagate leaves.Reference(inputs); leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); + aggr.function.GetStateUpdateCallback()(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, + nupdate); nupdate = 0; // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); + aggr.function.GetStateCombineCallback()(source_v, target_v, aggr_input_data, ncombine); ncombine = 0; } } @@ -627,11 +620,12 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti // Push the updates leaves.Reference(inputs); leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); + aggr.function.GetStateUpdateCallback()(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, + nupdate); nupdate = 0; // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); + aggr.function.GetStateCombineCallback()(source_v, target_v, aggr_input_data, ncombine); ncombine = 0; } @@ -644,9 +638,9 @@ void WindowDistinctAggregatorLocalState::FlushStates() { } const auto &aggr = gdstate.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), tree_allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); statel.Verify(flush_count); - aggr.function.combine(statel, statep, aggr_input_data, flush_count); + aggr.function.GetStateCombineCallback()(statel, statep, aggr_input_data, flush_count); flush_count = 0; } @@ -704,7 +698,6 @@ unique_ptr WindowDistinctAggregator::GetLocalState(ExecutionCont void WindowDistinctAggregator::Evaluate(ExecutionContext &context, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { - const auto &gdstate = sink.global_state.Cast(); auto &ldstate = sink.local_state.Cast(); ldstate.Evaluate(context, gdstate, bounds, result, count, row_idx); diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index 6af3d0e5b..5943e6228 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -1,5 +1,8 @@ #include "duckdb/function/window/window_merge_sort_tree.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" #include #include diff --git 
a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index a01e4813c..15750b26f 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -13,7 +13,6 @@ namespace duckdb { //===--------------------------------------------------------------------===// WindowNaiveAggregator::WindowNaiveAggregator(const WindowAggregateExecutor &executor, WindowSharedExpressions &shared) : WindowAggregator(executor.wexpr, shared), executor(executor) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterCollection(order.expression, false)); } @@ -166,7 +165,8 @@ void WindowNaiveLocalState::FlushStates(const WindowAggregatorGlobalState &gsink const auto &aggr = gsink.aggr; AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), statep, flush_count); + aggr.function.GetStateUpdateCallback()(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), statep, + flush_count); flush_count = 0; } @@ -235,7 +235,7 @@ void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggr WindowAggregator::EvaluateSubFrames(bounds, aggregator.exclude_mode, count, row_idx, frames, [&](idx_t rid) { auto agg_state = fdata[rid]; - aggr.function.initialize(aggr.function, agg_state); + aggr.function.GetStateInitCallback()(aggr.function, agg_state); // Reset the DISTINCT hash table row_set.clear(); @@ -352,11 +352,11 @@ void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggr // Finalise the result aggregates and write to the result AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); + aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); + if (aggr.function.HasStateDestructorCallback()) { + aggr.function.GetStateDestructorCallback()(statef, aggr_input_data, count); } } diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index af70521a0..90a0ee17f 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rank_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { @@ -55,6 +56,7 @@ class WindowPeerLocalState : public WindowExecutorBoundsLocalState { void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); + idx_t row_idx = DConstants::INVALID_INDEX; uint64_t dense_rank = 1; uint64_t rank_equal = 0; uint64_t rank = 1; @@ -103,7 +105,6 @@ void WindowPeerLocalState::NextRank(idx_t partition_begin, idx_t peer_begin, idx //===--------------------------------------------------------------------===// WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } @@ -181,48 +182,57 
@@ void WindowDenseRankExecutor::EvaluateInternal(ExecutionContext &context, DataCh auto rdata = FlatVector::GetData(result); // Reset to "previous" row - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - // The previous dense rank is the number of order mask bits in [partition_begin, row_idx) - lpeer.dense_rank = 0; - - auto order_begin = partition_begin[0]; - idx_t begin_idx; - idx_t begin_offset; - order_mask.GetEntryIndex(order_begin, begin_idx, begin_offset); - - auto order_end = row_idx; - idx_t end_idx; - idx_t end_offset; - order_mask.GetEntryIndex(order_end, end_idx, end_offset); - - // If they are in the same entry, just loop - if (begin_idx == end_idx) { - const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < end_offset; ++begin_offset) { - lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - } - } else { - // Count the ragged bits at the start of the partition - if (begin_offset) { + // Resetting is slow because we have to rescan the mask. + // So check whether we are just picking up where we left off. + // This is common because the main window operator + // evaluates maximally sized runs for each hash group. + if (lpeer.row_idx != row_idx) { + lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; + lpeer.rank_equal = (row_idx - peer_begin[0]); + + // The previous dense rank is the number of order mask bits in [partition_begin, row_idx) + lpeer.dense_rank = 0; + + auto order_begin = partition_begin[0]; + idx_t begin_idx; + idx_t begin_offset; + order_mask.GetEntryIndex(order_begin, begin_idx, begin_offset); + + auto order_end = row_idx; + idx_t end_idx; + idx_t end_offset; + order_mask.GetEntryIndex(order_end, end_idx, end_offset); + + // If they are in the same entry, just loop + if (begin_idx == end_idx) { const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < order_mask.BITS_PER_VALUE; ++begin_offset) { + for (; begin_offset < end_offset; ++begin_offset) { lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - ++order_begin; } - ++begin_idx; - } + } else { + // Count the ragged bits at the start of the partition + if (begin_offset) { + const auto entry = order_mask.GetValidityEntry(begin_idx); + for (; begin_offset < order_mask.BITS_PER_VALUE; ++begin_offset) { + lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); + ++order_begin; + } + ++begin_idx; + } - // Count the the aligned bits. - ValidityMask tail_mask(order_mask.GetData() + begin_idx, end_idx - begin_idx); - lpeer.dense_rank += tail_mask.CountValid(order_end - order_begin); + // Count the the aligned bits. 
+ ValidityMask tail_mask(order_mask.GetData() + begin_idx, end_idx - begin_idx); + lpeer.dense_rank += tail_mask.CountValid(order_end - order_begin); + } } for (idx_t i = 0; i < count; ++i, ++row_idx) { lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); rdata[i] = NumericCast(lpeer.dense_rank); } + + // Remember where we left off + lpeer.row_idx = row_idx; } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index f0929d642..27e7adecc 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rownumber_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { @@ -91,7 +92,6 @@ void WindowRowNumberLocalState::Finalize(ExecutionContext &context, CollectionPt //===--------------------------------------------------------------------===// WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } @@ -141,7 +141,6 @@ void WindowRowNumberExecutor::EvaluateInternal(ExecutionContext &context, DataCh //===--------------------------------------------------------------------===// WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowRowNumberExecutor(wexpr, shared) { - // NTILE has one argument ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); } diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index f62a0a856..b5e26cb99 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -32,12 +32,6 @@ class WindowSegmentTreeGlobalState : public WindowAggregatorGlobalState { WindowSegmentTreeGlobalState(ClientContext &context, const WindowSegmentTree &aggregator, idx_t group_count); - ArenaAllocator &CreateTreeAllocator() { - lock_guard tree_lock(lock); - tree_allocators.emplace_back(make_uniq(Allocator::DefaultAllocator())); - return *tree_allocators.back(); - } - //! The owning aggregator const WindowSegmentTree &tree; //! 
The actual window segment tree: an array of aggregate states that represent all the intermediate nodes @@ -160,11 +154,10 @@ void WindowSegmentTree::Finalize(ExecutionContext &context, CollectionPtr collec WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, unique_ptr cursor_p, const ValidityArray &filter_mask) : allocator(allocator), aggr(aggr), - order_insensitive(aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT), - filter_mask(filter_mask), state_size(aggr.function.state_size(aggr.function)), + order_insensitive(aggr.function.GetOrderDependent() == AggregateOrderDependent::NOT_ORDER_DEPENDENT), + filter_mask(filter_mask), state_size(aggr.function.GetStateSizeCallback()(aggr.function)), state(state_size * STANDARD_VECTOR_SIZE), cursor(std::move(cursor_p)), statep(LogicalType::POINTER), statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { - auto &inputs = cursor->chunk; if (inputs.ColumnCount() > 0) { leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); @@ -204,11 +197,12 @@ void WindowSegmentTreePart::FlushStates(bool combining) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); if (combining) { statel.Verify(flush_count); - aggr.function.combine(statel, statep, aggr_input_data, flush_count); + aggr.function.GetStateCombineCallback()(statel, statep, aggr_input_data, flush_count); } else { auto &scanned = cursor->chunk; leaves.Slice(scanned, filter_sel, flush_count); - aggr.function.update(&leaves.data[0], aggr_input_data, leaves.ColumnCount(), statep, flush_count); + aggr.function.GetStateUpdateCallback()(&leaves.data[0], aggr_input_data, leaves.ColumnCount(), statep, + flush_count); } flush_count = 0; @@ -216,7 +210,7 @@ void WindowSegmentTreePart::FlushStates(bool combining) { void WindowSegmentTreePart::Combine(WindowSegmentTreePart &other, idx_t count) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.combine(other.statef, statef, aggr_input_data, count); + aggr.function.GetStateCombineCallback()(other.statef, statef, aggr_input_data, count); } void WindowSegmentTreePart::ExtractFrame(idx_t begin, idx_t end, data_ptr_t state_ptr) { @@ -287,18 +281,17 @@ void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTreeGlobalStat void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { // Finalise the result aggregates and write to result if write_result is set AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); + aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); + if (aggr.function.HasStateDestructorCallback()) { + aggr.function.GetStateDestructorCallback()(statef, aggr_input_data, count); } } -WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &context, const WindowSegmentTree &aggregator, +WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &client, const WindowSegmentTree &aggregator, idx_t group_count) - : WindowAggregatorGlobalState(context, aggregator, group_count), tree(aggregator), levels_flat_native(aggr) { - + : WindowAggregatorGlobalState(client, aggregator, group_count), tree(aggregator), levels_flat_native(client, aggr) { D_ASSERT(!aggregator.wexpr.children.empty()); // compute space required 
to store internal nodes of segment tree @@ -349,7 +342,7 @@ void WindowSegmentTreeLocalState::Finalize(ExecutionContext &context, WindowAggr auto cursor = make_uniq(*collection, gastate.aggregator.child_idx); const auto leaf_count = collection->size(); auto &filter_mask = gstate.filter_mask; - WindowSegmentTreePart gtstate(gstate.CreateTreeAllocator(), gastate.aggr, std::move(cursor), filter_mask); + WindowSegmentTreePart gtstate(allocator, gastate.aggr, std::move(cursor), filter_mask); auto &levels_flat_native = gstate.levels_flat_native; const auto &levels_flat_start = gstate.levels_flat_start; @@ -464,7 +457,7 @@ void WindowSegmentTreePart::Initialize(idx_t count) { auto fdata = FlatVector::GetData(statef); for (idx_t rid = 0; rid < count; ++rid) { auto state_ptr = fdata[rid]; - aggr.function.initialize(aggr.function, state_ptr); + aggr.function.GetStateInitCallback()(aggr.function, state_ptr); } } @@ -570,7 +563,6 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTreeGlobalSta void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, const idx_t *bounds, idx_t count, idx_t row_idx, FramePart frame_part, FramePart leaf_part) { - auto fdata = FlatVector::GetData(statef); // For order-sensitive aggregates, we have to process the ragged leaves in two pieces. diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index 0258b7d6b..adf60be11 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -23,7 +23,6 @@ class WindowValueGlobalState : public WindowExecutorGlobalState { const ValidityMask &partition_mask, const ValidityMask &order_mask) : WindowExecutorGlobalState(client, executor, payload_count, partition_mask, order_mask), ignore_nulls(&all_valid), child_idx(executor.child_idx) { - if (!executor.arg_order_idx.empty()) { value_tree = make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, payload_count); @@ -139,7 +138,6 @@ void WindowValueLocalState::Finalize(ExecutionContext &context, CollectionPtr co //===--------------------------------------------------------------------===// WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowExecutor(wexpr, shared) { - for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } @@ -200,7 +198,6 @@ class WindowLeadLagGlobalState : public WindowValueGlobalState { const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) : WindowValueGlobalState(client, executor, payload_count, partition_mask, order_mask) { - if (value_tree) { use_framing = true; @@ -842,7 +839,6 @@ static fill_value_t GetFillValueFunction(const LogicalType &type) { WindowFillExecutor::WindowFillExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowValueExecutor(wexpr, shared) { - // We need the sort values for interpolation, so either use the range or the secondary ordering expression if (arg_order_idx.empty()) { // We use the range ordering, even if it has not been defined @@ -918,7 +914,6 @@ unique_ptr WindowFillExecutor::GetLocalState(ExecutionContext &c void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx, OperatorSinkInput &sink) const { - auto &lfstate = 
sink.local_state.Cast(); auto &cursor = *lfstate.cursor; diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index ccf5ad5ac..a23b16a00 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -255,6 +255,35 @@ typedef enum duckdb_file_flag { DUCKDB_FILE_FLAG_APPEND = 5, } duckdb_file_flag; +//! An enum over DuckDB's configuration option scopes. +//! This enum can be used to specify the default scope when creating a custom configuration option, +//! but it is also be used to determine the scope in which a configuration option is set when it is +//! changed or retrieved. +typedef enum duckdb_config_option_scope { + DUCKDB_CONFIG_OPTION_SCOPE_INVALID = 0, + // The option is set for the duration of the current transaction only. + // !! CURRENTLY NOT IMPLEMENTED !! + DUCKDB_CONFIG_OPTION_SCOPE_LOCAL = 1, + // The option is set for the current session/connection only. + DUCKDB_CONFIG_OPTION_SCOPE_SESSION = 2, + // Set the option globally for all sessions/connections. + DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL = 3, +} duckdb_config_option_scope; + +//! An enum over DuckDB's catalog entry types. +typedef enum duckdb_catalog_entry_type { + DUCKDB_CATALOG_ENTRY_TYPE_INVALID = 0, + DUCKDB_CATALOG_ENTRY_TYPE_TABLE = 1, + DUCKDB_CATALOG_ENTRY_TYPE_SCHEMA = 2, + DUCKDB_CATALOG_ENTRY_TYPE_VIEW = 3, + DUCKDB_CATALOG_ENTRY_TYPE_INDEX = 4, + DUCKDB_CATALOG_ENTRY_TYPE_PREPARED_STATEMENT = 5, + DUCKDB_CATALOG_ENTRY_TYPE_SEQUENCE = 6, + DUCKDB_CATALOG_ENTRY_TYPE_COLLATION = 7, + DUCKDB_CATALOG_ENTRY_TYPE_TYPE = 8, + DUCKDB_CATALOG_ENTRY_TYPE_DATABASE = 9, +} duckdb_catalog_entry_type; + //===--------------------------------------------------------------------===// // General type definitions //===--------------------------------------------------------------------===// @@ -548,6 +577,12 @@ typedef struct _duckdb_config { void *internal_ptr; } * duckdb_config; +//! A custom configuration option instance. Used to register custom options that can be set on a duckdb_config. +//! or by the user in SQL using `SET = `. +typedef struct _duckdb_config_option { + void *internal_ptr; +} * duckdb_config_option; + //! A logical type. //! Must be destroyed with `duckdb_destroy_logical_type`. typedef struct _duckdb_logical_type { @@ -699,6 +734,47 @@ typedef void (*duckdb_table_function_init_t)(duckdb_init_info info); //! The function to generate an output chunk during table function execution. typedef void (*duckdb_table_function_t)(duckdb_function_info info, duckdb_data_chunk output); +//===--------------------------------------------------------------------===// +// Copy function types +//===--------------------------------------------------------------------===// + +//! A COPY function. Must be destroyed with `duckdb_destroy_copy_function`. +typedef struct _duckdb_copy_function { + void *internal_ptr; +} * duckdb_copy_function; + +//! Info for the bind function of a COPY function. +typedef struct _duckdb_copy_function_bind_info { + void *internal_ptr; +} * duckdb_copy_function_bind_info; + +//! Info for the global initialization function of a COPY function. +typedef struct _duckdb_copy_function_global_init_info { + void *internal_ptr; +} * duckdb_copy_function_global_init_info; + +//! Info for the sink function of a COPY function. +typedef struct _duckdb_copy_function_sink_info { + void *internal_ptr; +} * duckdb_copy_function_sink_info; + +//! Info for the finalize function of a COPY function. 
+typedef struct _duckdb_copy_function_finalize_info { + void *internal_ptr; +} * duckdb_copy_function_finalize_info; + +//! The bind function to use when binding a COPY ... TO function. +typedef void (*duckdb_copy_function_bind_t)(duckdb_copy_function_bind_info info); + +//! The initialization function to use when initializing a COPY ... TO function. +typedef void (*duckdb_copy_function_global_init_t)(duckdb_copy_function_global_init_info info); + +//! The function to sink an input chunk into during execution of a COPY ... TO function. +typedef void (*duckdb_copy_function_sink_t)(duckdb_copy_function_sink_info info, duckdb_data_chunk input); + +//! The function to finalize the COPY ... TO function execution. +typedef void (*duckdb_copy_function_finalize_t)(duckdb_copy_function_finalize_info info); + //===--------------------------------------------------------------------===// // Cast types //===--------------------------------------------------------------------===// @@ -786,6 +862,35 @@ typedef struct _duckdb_file_handle { void *internal_ptr; } * duckdb_file_handle; +//===--------------------------------------------------------------------===// +// Catalog Interface +//===--------------------------------------------------------------------===// + +//! A handle to a database catalog. +//! Must be destroyed with `duckdb_destroy_catalog`. +typedef struct _duckdb_catalog { + void *internal_ptr; +} * duckdb_catalog; + +//! A handle to a catalog entry (e.g., table, view, index, etc.). +//! Must be destroyed with `duckdb_destroy_catalog_entry`. +typedef struct _duckdb_catalog_entry { + void *internal_ptr; +} * duckdb_catalog_entry; + +//===--------------------------------------------------------------------===// +// Logging Types +//===--------------------------------------------------------------------===// + +//! Holds a log storage object. +typedef struct _duckdb_log_storage { + void *internal_ptr; +} * duckdb_log_storage; + +//! This function is missing the logging context, which will be added later. +typedef void (*duckdb_logger_write_log_entry_t)(void *extra_data, duckdb_timestamp *timestamp, const char *level, + const char *log_type, const char *log_message); + //===--------------------------------------------------------------------===// // DuckDB extension access //===--------------------------------------------------------------------===// @@ -795,7 +900,7 @@ struct duckdb_extension_access { //! Indicate that an error has occurred. void (*set_error)(duckdb_extension_info info, const char *error); //! Fetch the database on which to register the extension. - duckdb_database *(*get_database)(duckdb_extension_info info); + duckdb_database (*get_database)(duckdb_extension_info info); //! Fetch the API struct pointer. const void *(*get_api)(duckdb_extension_info info, const char *version); }; @@ -806,9 +911,12 @@ struct duckdb_extension_access { // Functions //===--------------------------------------------------------------------===// -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Open Connect -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to operate on the instance cache, databases, connections, as well as some metadata functions. 
+//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new database instance cache. @@ -896,10 +1004,10 @@ Interrupt running query DUCKDB_C_API void duckdb_interrupt(duckdb_connection connection); /*! -Get progress of the running query +Get the progress of the running query. -* @param connection The working connection -* @return -1 if no progress or a percentage of the progress +* @param connection The connection running the query. +* @return The query progress type containing progress information. */ DUCKDB_C_API duckdb_query_progress_type duckdb_query_progress(duckdb_connection connection); @@ -968,9 +1076,12 @@ with duckdb_destroy_value. */ DUCKDB_C_API duckdb_value duckdb_get_table_names(duckdb_connection connection, const char *query, bool qualified); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Configuration -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with a `duckdb_config`, which is the configuration parameter for opening a database. +//---------------------------------------------------------------------------------------------------------------------- /*! Initializes an empty configuration object that can be used to provide start-up options for the DuckDB instance @@ -1031,12 +1142,13 @@ Destroys the specified configuration object and de-allocates all memory allocate */ DUCKDB_C_API void duckdb_destroy_config(duckdb_config *config); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Error Data -//===--------------------------------------------------------------------===// - -// Functions that can throw DuckDB errors must return duckdb_error_data. -// Please use this interface for all new functions, as it deprecates all previous error handling approaches. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to operate on `duckdb_error_data`, which contains, for example, the error type and message. Please use this +// interface for all new C API functions, as it supersedes previous error handling approaches. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates duckdb_error_data. @@ -1079,9 +1191,12 @@ Returns whether the error data contains an error or not. */ DUCKDB_C_API bool duckdb_error_data_has_error(duckdb_error_data error_data); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Query Execution -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to obtain a `duckdb_result` and to retrieve metadata from it. 
+//---------------------------------------------------------------------------------------------------------------------- /*! Executes a SQL query within a connection and stores the full (materialized) result in the out_result pointer. @@ -1254,10 +1369,6 @@ Returns the result error type contained within the result. The error is only set */ DUCKDB_C_API duckdb_error_type duckdb_result_error_type(duckdb_result *result); -//===--------------------------------------------------------------------===// -// Result Functions -//===--------------------------------------------------------------------===// - #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1310,14 +1421,20 @@ Returns the return_type of the given result, or DUCKDB_RETURN_TYPE_INVALID on er */ DUCKDB_C_API duckdb_result_type duckdb_result_return_type(duckdb_result result); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Safe Fetch Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Deprecated functions to interact with a `duckdb_result`. +// +// DEPRECATION NOTICE: +// This function group is deprecated and scheduled for removal. +// +// USE INSTEAD: +// To access the values in a result, use `duckdb_fetch_chunk` repeatedly. For each chunk, use the `duckdb_data_chunk` +// interface to access any columns and their values. +//---------------------------------------------------------------------------------------------------------------------- -// These functions will perform conversions if necessary. -// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned. -// Note that these functions are slow since they perform bounds checking and conversion -// For fast access of values prefer using `duckdb_result_get_chunk` #ifndef DUCKDB_API_NO_DEPRECATED /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. @@ -1446,8 +1563,7 @@ DUCKDB_C_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_ DUCKDB_C_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null -bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The text value at the specified location as a null-terminated string, or nullptr if the value cannot be converted. The result must be freed with `duckdb_free`. @@ -1457,16 +1573,12 @@ DUCKDB_C_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t /*! **DEPRECATION NOTICE**: This method is scheduled for removal in a future release. -No support for nested types, and for other complex types. -The resulting field "string.data" must be freed with `duckdb_free.` - * @return The string value at the specified location. Attempts to cast the result value to string. */ DUCKDB_C_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. 
+**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -1476,8 +1588,8 @@ The result must NOT be freed. DUCKDB_C_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row); /*! -**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains -null bytes. +**DEPRECATION NOTICE**: This method is scheduled for removal in a future release. + * @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast. If the column is NOT a VARCHAR column this function will return NULL. @@ -1502,9 +1614,12 @@ DUCKDB_C_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t r #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Generic and `duckdb_string_t` helper functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Allocate `size` bytes of memory using the duckdb internal malloc function. Any memory allocated in this manner @@ -1554,9 +1669,13 @@ Get a pointer to the string data of a string_t */ DUCKDB_C_API const char *duckdb_string_t_data(duckdb_string_t *string); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Date Time Timestamp Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_[date, time, time_tz, timestamp]`. +// `duckdb_is_finite_timestamp[_s, _ms, _ns]` helper functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Decompose a `duckdb_date` object into year, month and date (stored as `duckdb_date_struct`). @@ -1664,9 +1783,12 @@ Test a `duckdb_timestamp_ns` to see if it is a finite value. */ DUCKDB_C_API bool duckdb_is_finite_timestamp_ns(duckdb_timestamp_ns ts); -//===--------------------------------------------------------------------===// -// Hugeint Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// Hugeint and Uhugeint Helpers +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_[hugeint, uhugeint]`. +//---------------------------------------------------------------------------------------------------------------------- /*! Converts a duckdb_hugeint object (as obtained from a `DUCKDB_TYPE_HUGEINT` column) into a double. 
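// EDITOR'S NOTE: a brief usage sketch for the hugeint and uhugeint helpers described above (illustrative values only;
// this snippet is not part of the generated header):
//
//   duckdb_hugeint huge = duckdb_double_to_hugeint(1e18);    // yields 0 if the double is too large to represent
//   double narrowed = duckdb_hugeint_to_double(huge);        // convert back into a double
//   duckdb_uhugeint uhuge = duckdb_double_to_uhugeint(1e18); // unsigned counterpart
//   double unarrowed = duckdb_uhugeint_to_double(uhuge);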
@@ -1686,10 +1808,6 @@ If the conversion fails because the double value is too big the result will be 0 */ DUCKDB_C_API duckdb_hugeint duckdb_double_to_hugeint(double val); -//===--------------------------------------------------------------------===// -// Unsigned Hugeint Helpers -//===--------------------------------------------------------------------===// - /*! Converts a duckdb_uhugeint object (as obtained from a `DUCKDB_TYPE_UHUGEINT` column) into a double. @@ -1708,9 +1826,12 @@ If the conversion fails because the double value is too big the result will be 0 */ DUCKDB_C_API duckdb_uhugeint duckdb_double_to_uhugeint(double val); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Decimal Helpers -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to `duckdb_decimal`. +//---------------------------------------------------------------------------------------------------------------------- /*! Converts a double value to a duckdb_decimal object. @@ -1730,19 +1851,21 @@ Converts a duckdb_decimal object (as obtained from a `DUCKDB_TYPE_DECIMAL` colum */ DUCKDB_C_API double duckdb_decimal_to_double(duckdb_decimal val); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Prepared Statements -//===--------------------------------------------------------------------===// - -// A prepared statement is a parameterized query that allows you to bind parameters to it. -// * This is useful to easily supply parameters to functions and avoid SQL injection attacks. -// * This is useful to speed up queries that you will execute several times with different parameters. -// Because the query will only be parsed, bound, optimized and planned once during the prepare stage, -// rather than once per execution. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// A prepared statement is a parameterized query, and you can bind parameters to it. Prepared statements are commonly +// used to easily supply parameters to functions and avoid SQL injection attacks. They also speed up queries that are +// executed repeatedly with different parameters. That is because the query is only parsed, bound, optimized and planned +// once during the prepare stage, rather than once per execution, if it is possible to resolve all parameter types. +// // For example: -// SELECT * FROM tbl WHERE id=? +// SELECT * FROM tbl WHERE id = ? // Or a query with multiple parameters: -// SELECT * FROM tbl WHERE id=$1 OR name=$2 +// SELECT * FROM tbl WHERE id = $1 OR name = $2 +//---------------------------------------------------------------------------------------------------------------------- + /*! Create a prepared statement object from a query. @@ -1881,9 +2004,13 @@ Returns `DUCKDB_TYPE_INVALID` if the column is out of range. 
DUCKDB_C_API duckdb_type duckdb_prepared_statement_column_type(duckdb_prepared_statement prepared_statement, idx_t col_idx); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Bind Values to Prepared Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to bind values to prepared statements. Try to use `duckdb_bind_value` and the `duckdb_create_...` interface +// for all types. +//---------------------------------------------------------------------------------------------------------------------- /*! Binds a value to the prepared statement at the specified index. @@ -2026,9 +2153,12 @@ Binds a NULL value to the prepared statement at the specified index. */ DUCKDB_C_API duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Execute Prepared Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to execute a prepared statement. +//---------------------------------------------------------------------------------------------------------------------- /*! Executes the prepared statement with the given bound parameters, and returns a materialized query result. @@ -2066,11 +2196,14 @@ DUCKDB_C_API duckdb_state duckdb_execute_prepared_streaming(duckdb_prepared_stat #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Extract Statements -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// A query string can be extracted into multiple SQL statements. Each statement should be prepared and executed +// separately. +//---------------------------------------------------------------------------------------------------------------------- -// A query string can be extracted into multiple SQL statements. Each statement can be prepared and executed separately. /*! Extract all statements from a query. Note that after calling `duckdb_extract_statements`, the extracted statements should always be destroyed using @@ -2119,9 +2252,12 @@ De-allocates all memory allocated for the extracted statements. 
*/ DUCKDB_C_API void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Pending Result Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with a pending result. First, prepare a pending result, and then execute it. +//---------------------------------------------------------------------------------------------------------------------- /*! Executes the prepared statement with the given bound parameters, and returns a pending result. @@ -2224,9 +2360,14 @@ DUCKDB_PENDING_RESULT_READY, this function will return true. */ DUCKDB_C_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Value Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create a `duckdb_value` for each of DuckDB's supported data types, and to access the contents of a +// `duckdb_value`. The `duckdb_value` wrapper allows handling of primitive and arbitrarily (nested) types through the +// same interface. +//---------------------------------------------------------------------------------------------------------------------- /*! Destroys the value and de-allocates all memory allocated for that type. @@ -2869,9 +3010,12 @@ Returns the SQL string representation of the given value. */ DUCKDB_C_API char *duckdb_value_to_string(duckdb_value value); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Logical Type Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and interact with `duckdb_logical_type`. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a `duckdb_logical_type` from a primitive type. @@ -3159,9 +3303,17 @@ The type must have an alias DUCKDB_C_API duckdb_state duckdb_register_logical_type(duckdb_connection con, duckdb_logical_type type, duckdb_create_type_info info); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Data Chunk Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_data_chunk`. 
Data chunks pass through the different operators of DuckDB's +// execution engine, when, e.g., executing a scalar function. Additionally, a query result is composed of a sequence of +// data chunks. +// +// A data chunk contains a number of vectors, which, in turn, contain data in a columnar format. For the query result, +// the vectors are the result columns, and they contain the query result for each row. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates an empty data chunk with the specified column types. @@ -3224,9 +3376,13 @@ Sets the current number of tuples in a data chunk. */ DUCKDB_C_API void duckdb_data_chunk_set_size(duckdb_data_chunk chunk, idx_t size); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Vector Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_vector`. A vector typically (but not always) lives in a data chunk and contains a +// subset of the rows of a column. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a flat vector. Must be destroyed with `duckdb_destroy_vector`. @@ -3336,23 +3492,26 @@ Returns the size of the child vector of the list. DUCKDB_C_API idx_t duckdb_list_vector_get_size(duckdb_vector vector); /*! -Sets the total size of the underlying child-vector of a list vector. +Sets the size of the underlying child-vector of a list vector. +Note that this does NOT reserve the memory in the child buffer, +and that it is possible to set a size exceeding the capacity. +To set the capacity, use `duckdb_list_vector_reserve`. * @param vector The list vector. * @param size The size of the child list. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size); /*! -Sets the total capacity of the underlying child-vector of a list. - -After calling this method, you must call `duckdb_vector_get_validity` and `duckdb_vector_get_data` to obtain current -data and validity pointers +Sets the capacity of the underlying child-vector of a list vector. +We increment to the next power of two, based on the required capacity. +Thus, the capacity might not match the size of the list (capacity >= size), +which is set via `duckdb_list_vector_set_size`. * @param vector The list vector. -* @param required_capacity the total capacity to reserve. -* @return The duckdb state. Returns DuckDBError if the vector is nullptr. +* @param required_capacity The child buffer capacity to reserve. +* @return The duckdb state. Returns DuckDBError, if the vector is nullptr. */ DUCKDB_C_API duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity); @@ -3419,9 +3578,13 @@ Changes `to_vector` to reference `from_vector. 
After, the vectors share ownershi */ DUCKDB_C_API void duckdb_vector_reference_vector(duckdb_vector to_vector, duckdb_vector from_vector); -//===--------------------------------------------------------------------===// // Validity Mask Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with the validity mask of a vector. The validity mask is a bitmask determining whether a row in +// a vector is `NULL`, or not. +//---------------------------------------------------------------------------------------------------------------------- /*! Returns whether or not a row is valid (i.e. not NULL) in the given validity mask. @@ -3464,9 +3627,14 @@ Equivalent to `duckdb_validity_set_row_validity` with valid set to true. */ DUCKDB_C_API void duckdb_validity_set_row_valid(uint64_t *validity, idx_t row); -//===--------------------------------------------------------------------===// // Scalar Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom scalar functions. Scalar functions take one or more input +// parameters, and return a single output value. Consider using a table function, if your scalar function does not +// take any input parameters. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty scalar function. @@ -3702,9 +3870,14 @@ Returns the input argument at index of the scalar function. */ DUCKDB_C_API duckdb_expression duckdb_scalar_function_bind_get_argument(duckdb_bind_info info, idx_t index); -//===--------------------------------------------------------------------===// // Selection Vector Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with `duckdb_selection_vector`. Selection vectors define a selection on top of a vector. Let's +// say that a filter filters out all `VARCHAR` rows containing `hello`. Then, instead of creating a full new copy of the +// remaining rows, it is possible to use a selection vector that only selects the rows satisfying the filter. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new selection vector of size `size`. @@ -3730,9 +3903,13 @@ Access the data pointer of a selection vector.
*/ DUCKDB_C_API sel_t *duckdb_selection_vector_get_data_ptr(duckdb_selection_vector sel); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Aggregate Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom aggregate functions. Aggregate functions aggregate the values of a +// column into an output value. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty aggregate function. @@ -3887,9 +4064,13 @@ If the set is incomplete or a function with this name already exists DuckDBError DUCKDB_C_API duckdb_state duckdb_register_aggregate_function_set(duckdb_connection con, duckdb_aggregate_function_set set); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom table functions. Table functions take one or more input parameters, +// and return one or more output parameters. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new empty table function. @@ -4005,9 +4186,13 @@ If the function is incomplete or a function with this name already exists DuckDB */ DUCKDB_C_API duckdb_state duckdb_register_table_function(duckdb_connection con, duckdb_table_function function); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function Bind -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the bind-phase of a table function. The bind-phase happens once before the execution of the +// table function. It is useful to, e.g., set up any read-only information for the different threads during execution. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4090,9 +4275,13 @@ Report that an error has occurred while calling bind on a table function. 
*/ DUCKDB_C_API void duckdb_bind_set_error(duckdb_bind_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function Init -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the init-phase of a table function. The init-phase happens once for each thread and +// initializes thread-local information prior to execution. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4159,9 +4348,13 @@ Report that an error has occurred while calling init. */ DUCKDB_C_API void duckdb_init_set_error(duckdb_init_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Function -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to implement the execution callback of a table function. The execution callback (i.e., the main function) +// produces a data chunk output based on a data chunk input, and has access to both the bind and init data. +//---------------------------------------------------------------------------------------------------------------------- /*! Retrieves the extra info of the function as set in `duckdb_table_function_set_extra_info`. @@ -4206,9 +4399,13 @@ Report that an error has occurred while executing the function. */ DUCKDB_C_API void duckdb_function_set_error(duckdb_function_info info, const char *error); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Replacement Scans -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register a custom replacement scan. A replacement scan is a callback replacing a +// scan of a table that does not exist in the catalog. +//---------------------------------------------------------------------------------------------------------------------- /*! Add a replacement scan definition to the specified database. @@ -4247,9 +4444,12 @@ Report that an error has occurred while executing the replacement scan. 
*/ DUCKDB_C_API void duckdb_replacement_scan_set_error(duckdb_replacement_scan_info info, const char *error); -//===--------------------------------------------------------------------===// // Profiling Info -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to access the post-execution profiling information of a query. Only available, if profiling is enabled. +//---------------------------------------------------------------------------------------------------------------------- /*! Returns the root node of the profiling information. Returns nullptr, if profiling is not enabled. @@ -4296,23 +4496,17 @@ Returns the child node at the specified index. */ DUCKDB_C_API duckdb_profiling_info duckdb_profiling_info_get_child(duckdb_profiling_info info, idx_t index); -//===--------------------------------------------------------------------===// // Appender -//===--------------------------------------------------------------------===// - -// Appenders are the most efficient way of loading data into DuckDB from within the C API. -// They are recommended for fast data loading as they perform better than prepared statements or individual `INSERT -// INTO` statements. - -// Appends are possible in row-wise format, and by appending entire data chunks. - -// Row-wise: for every column, a `duckdb_append_[type]` call should be made. After finishing all appends to a row, call -// `duckdb_appender_end_row`. - -// Chunk-wise: Consecutively call `duckdb_append_data_chunk` until all chunks have been appended. - -// After all data has been appended, call `duckdb_appender_close` to finalize the appender followed by -// `duckdb_appender_destroy` to clean up the memory. +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Appenders are the most efficient way of bulk-loading data into DuckDB. They are recommended for fast data loading as +// they perform better than prepared statements or individual `INSERT INTO` statements. Appends are possible in row-wise +// format, and by appending entire data chunks. Try to use chunk-wise appends via `duckdb_append_data_chunk` to ensure +// support for all of DuckDB's data types. Chunk-wise appends consecutively call `duckdb_append_data_chunk` until all +// chunks have been appended. Afterward, call `duckdb_appender_destroy` to flush any outstanding data and to destroy the +// appender instance. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates an appender object. @@ -4421,6 +4615,15 @@ duckdb_appender_destroy to destroy the invalidated appender. */ DUCKDB_C_API duckdb_state duckdb_appender_flush(duckdb_appender appender); +/*! +Clears all buffered data from the appender without flushing it to the table. This discards any data that has been +appended but not yet written. The appender can continue to be used after clearing. + +* @param appender The appender to clear. +* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
+*/ +DUCKDB_C_API duckdb_state duckdb_appender_clear(duckdb_appender appender); + /*! Closes the appender by flushing all intermediate states and closing it for further appends. If flushing the data triggers a constraint violation or any other error, then all data is invalidated, and this function returns DuckDBError. @@ -4616,9 +4819,12 @@ Appends a pre-filled data chunk to the specified appender. */ DUCKDB_C_API duckdb_state duckdb_append_data_chunk(duckdb_appender appender, duckdb_data_chunk chunk); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Table Description -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and access a `duckdb_table_description` instance. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a table description object. Note that `duckdb_table_description_destroy` should always be called on the @@ -4675,6 +4881,14 @@ Check if the column at 'index' index of the table has a DEFAULT expression. */ DUCKDB_C_API duckdb_state duckdb_column_has_default(duckdb_table_description table_description, idx_t index, bool *out); +/*! +Return the number of columns of the described table. + +* @param table_description The table_description to query. +* @return The column count. +*/ +DUCKDB_C_API idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description); + /*! Obtain the column name at 'index'. The out result must be destroyed with `duckdb_free`. @@ -4685,9 +4899,23 @@ The out result must be destroyed with `duckdb_free`. */ DUCKDB_C_API char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index); -//===--------------------------------------------------------------------===// +/*! +Obtain the column type at 'index'. +The return value must be destroyed with `duckdb_destroy_logical_type`. + +* @param table_description The table_description to query. +* @param index The index of the column to query. +* @return The column type. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, + idx_t index); + +//---------------------------------------------------------------------------------------------------------------------- // Arrow Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to convert from and to Arrow. +//---------------------------------------------------------------------------------------------------------------------- /*! 
Transforms a DuckDB Schema into an Arrow Schema @@ -4927,9 +5155,12 @@ DUCKDB_C_API duckdb_state duckdb_arrow_array_scan(duckdb_connection connection, #endif -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Threading Information -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and execute tasks. +//---------------------------------------------------------------------------------------------------------------------- /*! Execute DuckDB tasks on this thread. @@ -5008,9 +5239,12 @@ Returns true if the execution of the current query is finished. */ DUCKDB_C_API bool duckdb_execution_is_finished(duckdb_connection con); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Streaming Result Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to stream a `duckdb_result`. Call `duckdb_fetch_chunk` until the result is exhausted. +//---------------------------------------------------------------------------------------------------------------------- #ifndef DUCKDB_API_NO_DEPRECATED /*! @@ -5047,9 +5281,12 @@ It is not known beforehand how many chunks will be returned by this result. */ DUCKDB_C_API duckdb_data_chunk duckdb_fetch_chunk(duckdb_result result); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Cast Functions -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, execute, and register custom cast functions. +//---------------------------------------------------------------------------------------------------------------------- /*! Creates a new cast function object. @@ -5153,9 +5390,13 @@ Destroys the cast function object. */ DUCKDB_C_API void duckdb_destroy_cast_function(duckdb_cast_function *cast_function); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // Expression Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create and access expressions. Expressions are widespread in DuckDB, especially during query planning. +// E.g., scalar function parameters are expressions, and can be inspected during the bind-phase. +//---------------------------------------------------------------------------------------------------------------------- /*! Destroys the expression and de-allocates its memory. 
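// EDITOR'S NOTE: a sketch of inspecting and constant-folding a scalar function argument during the bind-phase
// (illustrative only; assumes `info` is the duckdb_bind_info of the bind callback, `ctx` is a duckdb_client_context
// available to it, and the usual `duckdb_destroy_*` counterparts are used for cleanup):
//
//   duckdb_expression expr = duckdb_scalar_function_bind_get_argument(info, 0);
//   duckdb_value folded = NULL;
//   duckdb_error_data err = duckdb_expression_fold(ctx, expr, &folded);
//   if (!duckdb_error_data_has_error(err)) {
//       // use the folded constant value here, e.g. via duckdb_get_int64(folded)
//   }
//   duckdb_destroy_value(&folded);
//   duckdb_destroy_error_data(&err);
//   duckdb_destroy_expression(&expr);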
@@ -5191,9 +5432,13 @@ Folds an expression creating a folded value. DUCKDB_C_API duckdb_error_data duckdb_expression_fold(duckdb_client_context context, duckdb_expression expr, duckdb_value *out_value); -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- // File System Interface -//===--------------------------------------------------------------------===// +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to access the file system of a connection and to interact with file handles. File handle instances to files +// allow operations such as reading, writing, and seeking in a file. +//---------------------------------------------------------------------------------------------------------------------- /*! Get a file system instance associated with the given client context. @@ -5335,6 +5580,578 @@ Closes the given file handle. */ DUCKDB_C_API duckdb_state duckdb_file_handle_close(duckdb_file_handle file_handle); +//---------------------------------------------------------------------------------------------------------------------- +// Config Options Interface +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to create, configure, and register custom configuration options. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Creates a configuration option instance. + +* @return The resulting configuration option instance. Must be destroyed with `duckdb_destroy_config_option`. +*/ +DUCKDB_C_API duckdb_config_option duckdb_create_config_option(); + +/*! +Destroys the given configuration option instance. +* @param option The configuration option instance to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_config_option(duckdb_config_option *option); + +/*! +Sets the name of the configuration option. + +* @param option The configuration option instance. +* @param name The name to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_name(duckdb_config_option option, const char *name); + +/*! +Sets the type of the configuration option. + +* @param option The configuration option instance. +* @param type The type to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_type(duckdb_config_option option, duckdb_logical_type type); + +/*! +Sets the default value of the configuration option. +If the type of this option has already been set with `duckdb_config_option_set_type`, the value is cast to the type. +Otherwise, the type is inferred from the value. + +* @param option The configuration option instance. +* @param default_value The default value to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_default_value(duckdb_config_option option, duckdb_value default_value); + +/*! +Sets the default scope of the configuration option. +If not set, this defaults to `DUCKDB_CONFIG_OPTION_SCOPE_SESSION`. + +* @param option The configuration option instance. +* @param default_scope The default scope to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_default_scope(duckdb_config_option option, + duckdb_config_option_scope default_scope); + +/*! +Sets the description of the configuration option. + +* @param option The configuration option instance. 
+* @param description The description to set. +*/ +DUCKDB_C_API void duckdb_config_option_set_description(duckdb_config_option option, const char *description); + +/*! +Registers the given configuration option on the specified connection. + +* @param connection The connection to register the option on. +* @param option The configuration option instance to register. +* @return A duckdb_state indicating success or failure. +*/ +DUCKDB_C_API duckdb_state duckdb_register_config_option(duckdb_connection connection, duckdb_config_option option); + +/*! +Retrieves the value of a configuration option by name from the given client context. + +* @param context The client context. +* @param name The name of the configuration option to retrieve. +* @param out_scope Output parameter to optionally store the scope that the configuration option was retrieved from. +If this is `nullptr`, the scope is not returned. +If the requested option does not exist the scope is set to `DUCKDB_CONFIG_OPTION_SCOPE_INVALID`. +* @return The value of the configuration option. Returns `nullptr` if the option does not exist. +*/ +DUCKDB_C_API duckdb_value duckdb_client_context_get_config_option(duckdb_client_context context, const char *name, + duckdb_config_option_scope *out_scope); + +//---------------------------------------------------------------------------------------------------------------------- +// Copy Functions +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to copy data from and to external file formats. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Creates a new empty copy function. + +The return value must be destroyed with `duckdb_destroy_copy_function`. + +* @return The copy function object. +*/ +DUCKDB_C_API duckdb_copy_function duckdb_create_copy_function(); + +/*! +Sets the name of the copy function. + +* @param copy_function The copy function +* @param name The name to set +*/ +DUCKDB_C_API void duckdb_copy_function_set_name(duckdb_copy_function copy_function, const char *name); + +/*! +Sets the extra info pointer of the copy function, which can be used to store arbitrary data. + +* @param copy_function The copy function +* @param extra_info The extra info pointer +* @param destructor A destructor function to call to destroy the extra info +*/ +DUCKDB_C_API void duckdb_copy_function_set_extra_info(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + +/*! +Registers the given copy function on the database connection under the specified name. + +* @param connection The database connection +* @param copy_function The copy function to register +*/ +DUCKDB_C_API duckdb_state duckdb_register_copy_function(duckdb_connection connection, + duckdb_copy_function copy_function); + +/*! +Destroys the given copy function object. +* @param copy_function The copy function to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_copy_function(duckdb_copy_function *copy_function); + +/*! +Sets the bind function of the copy function, to use when binding `COPY ... TO`. + +* @param bind The bind function +*/ +DUCKDB_C_API void duckdb_copy_function_set_bind(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + +/*! +Report that an error occurred during the binding-phase of a `COPY ... TO` function. 
+ +* @param info The bind info provided to the bind function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_bind_set_error(duckdb_copy_function_bind_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The bind info provided to the bind function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_bind_get_extra_info(duckdb_copy_function_bind_info info); + +/*! +Retrieves the client context of the current connection binding the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The bind info provided to the bind function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context duckdb_copy_function_bind_get_client_context(duckdb_copy_function_bind_info info); + +/*! +Retrieves the number of columns that will be provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @return The number of columns. +*/ +DUCKDB_C_API idx_t duckdb_copy_function_bind_get_column_count(duckdb_copy_function_bind_info info); + +/*! +Retrieves the type of a column that will be provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @param col_idx The index of the column to retrieve the type for +* @return The type of the column. Must be destroyed with `duckdb_destroy_logical_type`. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_copy_function_bind_get_column_type(duckdb_copy_function_bind_info info, + idx_t col_idx); + +/*! +Retrieves all values for the given options provided to the `COPY ... TO` function. + +* @param info The bind info provided to the bind function +* @return A STRUCT value containing all options as fields. Must be destroyed with `duckdb_destroy_value`. +*/ +DUCKDB_C_API duckdb_value duckdb_copy_function_bind_get_options(duckdb_copy_function_bind_info info); + +/*! +Sets the bind data of the copy function, to be provided to the init, sink and finalize functions. + +* @param info The bind info provided to the bind function +* @param bind_data The bind data pointer +* @param destructor A destructor function to call to destroy the bind data +*/ +DUCKDB_C_API void duckdb_copy_function_bind_set_bind_data(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + +/*! +Sets the initialization function of the copy function, called right before executing `COPY ... TO`. + +* @param init The init function +*/ +DUCKDB_C_API void duckdb_copy_function_set_global_init(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + +/*! +Report that an error occurred during the initialization-phase of a `COPY ... TO` function. + +* @param info The init info provided to the init function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_global_init_set_error(duckdb_copy_function_global_init_info info, + const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The init info provided to the init function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_global_init_get_extra_info(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the client context of the current connection initializing the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The init info provided to the init function +* @return The client context. 
+*/ +DUCKDB_C_API duckdb_client_context +duckdb_copy_function_global_init_get_client_context(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The init info provided to the init function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_global_init_get_bind_data(duckdb_copy_function_global_init_info info); + +/*! +Retrieves the file path provided to the `COPY ... TO` function. + +Lives for the duration of the initialization callback, must not be destroyed. + +* @param info The init info provided to the init function +* @return The file path. +*/ +DUCKDB_C_API const char *duckdb_copy_function_global_init_get_file_path(duckdb_copy_function_global_init_info info); + +/*! +Sets the global state of the copy function, to be provided to all subsequent local init, sink and finalize functions. + +* @param info The init info provided to the init function +* @param global_state The global state pointer +* @param destructor A destructor function to call to destroy the global state +*/ +DUCKDB_C_API void duckdb_copy_function_global_init_set_global_state(duckdb_copy_function_global_init_info info, + void *global_state, + duckdb_delete_callback_t destructor); + +/*! +Sets the sink function of the copy function, called during `COPY ... TO`. + +* @param function The sink function +*/ +DUCKDB_C_API void duckdb_copy_function_set_sink(duckdb_copy_function copy_function, + duckdb_copy_function_sink_t function); + +/*! +Report that an error occurred during the sink-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_sink_set_error(duckdb_copy_function_sink_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The sink info provided to the sink function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_extra_info(duckdb_copy_function_sink_info info); + +/*! +Retrieves the client context of the current connection during the sink-phase of the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The sink info provided to the sink function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context duckdb_copy_function_sink_get_client_context(duckdb_copy_function_sink_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_bind_data(duckdb_copy_function_sink_info info); + +/*! +Retrieves the global state provided during the init-phase of a `COPY ... TO` function. + +* @param info The sink info provided to the sink function +* @return The global state pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_sink_get_global_state(duckdb_copy_function_sink_info info); + +/*! +Sets the finalize function of the copy function, called at the end of `COPY ... TO`. + +* @param finalize The finalize function +*/ +DUCKDB_C_API void duckdb_copy_function_set_finalize(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + +/*! +Report that an error occurred during the finalize-phase of a `COPY ... 
TO` function. + +* @param info The finalize info provided to the finalize function +* @param error The error message +*/ +DUCKDB_C_API void duckdb_copy_function_finalize_set_error(duckdb_copy_function_finalize_info info, const char *error); + +/*! +Retrieves the extra info pointer of the copy function. + +* @param info The finalize info provided to the finalize function +* @return The extra info pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_extra_info(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the client context of the current connection during the finalize-phase of the `COPY ... TO` function. + +Must be destroyed with `duckdb_destroy_client_context` + +* @param info The finalize info provided to the finalize function +* @return The client context. +*/ +DUCKDB_C_API duckdb_client_context +duckdb_copy_function_finalize_get_client_context(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the bind data provided during the binding-phase of a `COPY ... TO` function. + +* @param info The finalize info provided to the finalize function +* @return The bind data pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_bind_data(duckdb_copy_function_finalize_info info); + +/*! +Retrieves the global state provided during the init-phase of a `COPY ... TO` function. + +* @param info The finalize info provided to the finalize function +* @return The global state pointer. +*/ +DUCKDB_C_API void *duckdb_copy_function_finalize_get_global_state(duckdb_copy_function_finalize_info info); + +/*! +Sets the table function to use when executing a `COPY ... FROM (...)` statement with this copy function. + +The table function must have a `duckdb_table_function_bind_t`, `duckdb_table_function_init_t` and +`duckdb_table_function_t` set. + +The table function must take a single VARCHAR parameter (the file path). + +Options passed to the `COPY ... FROM (...)` statement are forwarded as named parameters to the table function. + +Since `COPY ... FROM` copies into an already existing table, the table function should not define its own result columns +using `duckdb_bind_add_result_column` when binding. Instead, use `duckdb_table_function_bind_get_result_column_count` +and related functions in the bind callback of the table function to retrieve the schema of the target table of the `COPY +... FROM` statement. + +* @param copy_function The copy function +* @param table_function The table function to use for `COPY ... FROM` +*/ +DUCKDB_C_API void duckdb_copy_function_set_copy_from_function(duckdb_copy_function copy_function, + duckdb_table_function table_function); + +/*! +Retrieves the number of result columns of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the number of columns in the +target table at the start of the bind callback. + +* @param info The bind info provided to the bind function +* @return The number of result columns. +*/ +DUCKDB_C_API idx_t duckdb_table_function_bind_get_result_column_count(duckdb_bind_info info); + +/*! +Retrieves the name of a result column of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the names of the columns in +the target table at the start of the bind callback. + +The result is valid for the duration of the bind callback or until the next call to `duckdb_bind_add_result_column`, so +it must not be destroyed.
+ +* @param info The bind info provided to the bind function +* @param col_idx The index of the result column to retrieve the name for +* @return The name of the result column. +*/ +DUCKDB_C_API const char *duckdb_table_function_bind_get_result_column_name(duckdb_bind_info info, idx_t col_idx); + +/*! +Retrieves the type of a result column of a table function. + +If the table function is used in a `COPY ... FROM` statement, this can be used to retrieve the types of the columns in +the target table at the start of the bind callback. + +The result must be destroyed with `duckdb_destroy_logical_type`. + +* @param info The bind info provided to the bind function +* @param col_idx The index of the result column to retrieve the type for +* @return The type of the result column. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_function_bind_get_result_column_type(duckdb_bind_info info, + idx_t col_idx); + +//---------------------------------------------------------------------------------------------------------------------- +// Catalog Interface +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions to interact with database catalogs and catalog entries. +// You will most likely not need this API for typical usage of DuckDB as SQL is the preferred way to interact with the +// database, but this interface can be useful for advanced extensions that need to inspect the state of the catalog from +// inside a running query. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Retrieve a database catalog instance by name. +This function can only be called from within the context of an active transaction, e.g. during execution of a registered +function callback. Otherwise returns `nullptr`. +* @param context The client context. +* @param catalog_name The name of the catalog. +* @return The resulting catalog instance, or `nullptr` if called from outside an active transaction or if a catalog with +the specified name does not exist. Must be destroyed with `duckdb_destroy_catalog` +*/ +DUCKDB_C_API duckdb_catalog duckdb_client_context_get_catalog(duckdb_client_context context, const char *catalog_name); + +/*! +Retrieve the "type name" of the given catalog. +E.g. for a DuckDB database, this returns 'duckdb'. +The returned string is owned by the catalog and remains valid until the catalog is destroyed. + +* @param catalog The catalog. +* @return The type name of the catalog. +*/ +DUCKDB_C_API const char *duckdb_catalog_get_type_name(duckdb_catalog catalog); + +/*! +Retrieve a catalog entry from the given catalog by type, schema name and entry name. +The returned catalog entry remains valid for the duration of the current transaction. + +* @param catalog The catalog. +* @param context The client context. +* @param entry_type The type of the catalog entry to retrieve. +* @param schema_name The schema name of the catalog entry. +* @param entry_name The name of the catalog entry. +* @return The resulting catalog entry, or `nullptr` if no such entry exists. Must be destroyed with +`duckdb_destroy_catalog_entry`. Remains valid for the duration of the current transaction. +*/ +DUCKDB_C_API duckdb_catalog_entry duckdb_catalog_get_entry(duckdb_catalog catalog, duckdb_client_context context, + duckdb_catalog_entry_type entry_type, + const char *schema_name, const char *entry_name); + +/*! +Destroys the given catalog instance. 
+ +Note that this does not actually "drop" the contents of the catalog; it merely frees the C API handle. + +* @param catalog The catalog instance to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_catalog(duckdb_catalog *catalog); + +/*! +Get the type of the given catalog entry. + +* @param entry The catalog entry. +* @return The type of the catalog entry. +*/ +DUCKDB_C_API duckdb_catalog_entry_type duckdb_catalog_entry_get_type(duckdb_catalog_entry entry); + +/*! +Get the name of the given catalog entry. + +* @param entry The catalog entry. +* @return The name of the catalog entry. The returned string is owned by the catalog entry and remains valid until the +catalog entry is destroyed. +*/ +DUCKDB_C_API const char *duckdb_catalog_entry_get_name(duckdb_catalog_entry entry); + +/*! +Destroys the given catalog entry instance. + +Note that this does not actually "drop" the catalog entry from the database catalog; it merely frees the C API handle. + +* @param entry The catalog entry instance to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_catalog_entry(duckdb_catalog_entry *entry); + +//---------------------------------------------------------------------------------------------------------------------- +// Logging +//---------------------------------------------------------------------------------------------------------------------- +// DESCRIPTION: +// Functions exposing the log storage, which allows the configuration of a custom logger. This API is not yet ready to +// be stabilized. +//---------------------------------------------------------------------------------------------------------------------- + +/*! +Creates a new log storage object. + +* @return A log storage object. Must be destroyed with `duckdb_destroy_log_storage`. +*/ +DUCKDB_C_API duckdb_log_storage duckdb_create_log_storage(); + +/*! +Destroys a log storage object. + +* @param log_storage The log storage object to destroy. +*/ +DUCKDB_C_API void duckdb_destroy_log_storage(duckdb_log_storage *log_storage); + +/*! +Sets the callback function for writing log entries. + +* @param log_storage The log storage object. +* @param function The function to call. +*/ +DUCKDB_C_API void duckdb_log_storage_set_write_log_entry(duckdb_log_storage log_storage, + duckdb_logger_write_log_entry_t function); + +/*! +Sets the extra data of the custom log storage. + +* @param log_storage The log storage object. +* @param extra_data The extra data that is passed back into the callbacks. +* @param delete_callback The delete callback to call on the extra data, if any. +*/ +DUCKDB_C_API void duckdb_log_storage_set_extra_data(duckdb_log_storage log_storage, void *extra_data, + duckdb_delete_callback_t delete_callback); + +/*! +Sets the name of the log storage. + +* @param log_storage The log storage object. +* @param name The name of the log storage. +*/ +DUCKDB_C_API void duckdb_log_storage_set_name(duckdb_log_storage log_storage, const char *name); + +/*! +Registers a custom log storage for the logger. + +* @param database A database object. +* @param log_storage The log storage object. +* @return Whether the registration was successful. 
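A minimal sketch (not part of this patch) of the registration flow described above. `my_write_log_entry`, `my_state` and the storage name are hypothetical; the exact callback signature behind `duckdb_logger_write_log_entry_t` is defined elsewhere in the header and is only referenced here. Destroying the local handle right after registration is assumed to be safe, mirroring the other create/register/destroy pairs of the C API.

/* user-provided callback matching duckdb_logger_write_log_entry_t (signature defined elsewhere in duckdb.h) */
extern duckdb_logger_write_log_entry_t my_write_log_entry;

static duckdb_state register_my_log_storage(duckdb_database db, void *my_state) {
	duckdb_log_storage storage = duckdb_create_log_storage();
	duckdb_log_storage_set_name(storage, "my_log_storage");              /* placeholder name */
	duckdb_log_storage_set_write_log_entry(storage, my_write_log_entry); /* write callback */
	duckdb_log_storage_set_extra_data(storage, my_state, NULL);          /* no delete callback */
	duckdb_state result = duckdb_register_log_storage(db, storage);
	duckdb_destroy_log_storage(&storage); /* assumed safe after registration; frees the C API handle */
	return result;
}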
+*/ +DUCKDB_C_API duckdb_state duckdb_register_log_storage(duckdb_database database, duckdb_log_storage log_storage); + #endif #ifdef __cplusplus diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 0cf71fa73..7f206a43d 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -47,7 +47,7 @@ class DuckTableEntry : public TableCatalogEntry { TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; - vector GetColumnSegmentInfo() override; + vector GetColumnSegmentInfo(const QueryContext &context) override; TableStorageInfo GetStorageInfo(ClientContext &context) override; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index 5cab72c59..1e40319cf 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -111,7 +111,7 @@ class TableCatalogEntry : public StandardEntry { static string ColumnNamesToSQL(const ColumnList &columns); //! Returns a list of segment information for this table, if exists - virtual vector GetColumnSegmentInfo(); + virtual vector GetColumnSegmentInfo(const QueryContext &context); //! Returns the storage info of this table virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp index bbea06a7c..a333cf090 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/catalog/catalog_entry/macro_catalog_entry.hpp +// duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp // // //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp index 38e83238b..8cce35abf 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp @@ -24,6 +24,7 @@ namespace duckdb { struct AlterInfo; +struct ChangeOwnershipInfo; class ClientContext; class LogicalDependencyList; diff --git a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp index e2265c8c7..f3a71b594 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp @@ -19,7 +19,7 @@ struct DefaultType { LogicalTypeId type; }; -using builtin_type_array = std::array; +using builtin_type_array = std::array; static constexpr const builtin_type_array BUILTIN_TYPES{{ {"decimal", LogicalTypeId::DECIMAL}, @@ -97,7 +97,8 @@ static constexpr const builtin_type_array BUILTIN_TYPES{{ {"real", LogicalTypeId::FLOAT}, {"float4", LogicalTypeId::FLOAT}, {"double", LogicalTypeId::DOUBLE}, - {"float8", LogicalTypeId::DOUBLE} + {"float8", 
LogicalTypeId::DOUBLE}, + {"geometry", LogicalTypeId::GEOMETRY} }}; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h b/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h deleted file mode 100644 index a51eb0604..000000000 --- a/src/duckdb/src/include/duckdb/common/adbc/driver_manager.h +++ /dev/null @@ -1,82 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "duckdb/common/adbc/adbc.h" - -#ifdef __cplusplus -extern "C" { -#endif - -#ifndef ADBC_DRIVER_MANAGER_H -#define ADBC_DRIVER_MANAGER_H -/// \brief Common entry point for drivers via the driver manager. -/// -/// The driver manager can fill in default implementations of some -/// ADBC functions for drivers. Drivers must implement a minimum level -/// of functionality for this to be possible, however, and some -/// functions must be implemented by the driver. -/// -/// \param[in] driver_name An identifier for the driver (e.g. a path to a -/// shared library on Linux). -/// \param[in] entrypoint An identifier for the entrypoint (e.g. the -/// symbol to call for AdbcDriverInitFunc on Linux). -/// \param[in] version The ADBC revision to attempt to initialize. -/// \param[out] driver The table of function pointers to initialize. -/// \param[out] error An optional location to return an error message -/// if necessary. -ADBC_EXPORT -AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *driver, - struct AdbcError *error); - -/// \brief Common entry point for drivers via the driver manager. -/// -/// The driver manager can fill in default implementations of some -/// ADBC functions for drivers. Drivers must implement a minimum level -/// of functionality for this to be possible, however, and some -/// functions must be implemented by the driver. -/// -/// \param[in] init_func The entrypoint to call. -/// \param[in] version The ADBC revision to attempt to initialize. -/// \param[out] driver The table of function pointers to initialize. -/// \param[out] error An optional location to return an error message -/// if necessary. -ADBC_EXPORT -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *driver, - struct AdbcError *error); - -/// \brief Set the AdbcDriverInitFunc to use. -/// -/// This is an extension to the ADBC API. The driver manager shims -/// the AdbcDatabase* functions to allow you to specify the -/// driver/entrypoint dynamically. This function lets you set the -/// entrypoint explicitly, for applications that can dynamically -/// load drivers on their own. 
-ADBC_EXPORT -AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, - struct AdbcError *error); - -/// \brief Get a human-friendly description of a status code. -ADBC_EXPORT -const char *AdbcStatusCodeMessage(AdbcStatusCode code); - -#endif // ADBC_DRIVER_MANAGER_H - -#ifdef __cplusplus -} -#endif diff --git a/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp b/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp new file mode 100644 index 000000000..501d0fe37 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_containers/arena_ptr.hpp @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_containers/arena_ptr.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unique_ptr.hpp" + +namespace duckdb { + +//! Call destructor without attempting to free the memory +template +struct arena_deleter { // NOLINT: match stl case + void operator()(T *pointer) { + pointer->~T(); + } +}; + +template +using arena_ptr = unique_ptr>; + +template +using unsafe_arena_ptr = unique_ptr, false>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp b/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp new file mode 100644 index 000000000..4c05cc37b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_containers/arena_vector.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_containers/arena_vector.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arena_stl_allocator.hpp" + +namespace duckdb { + +template +using arena_vector = vector>; + +template +using unsafe_arena_vector = vector>; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp b/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp new file mode 100644 index 000000000..5f7582df6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/arena_stl_allocator.hpp @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/arena_stl_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +template +class arena_stl_allocator { // NOLINT: match stl case +public: + //! Typedefs + typedef T value_type; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + typedef value_type &reference; + typedef value_type const &const_reference; + typedef value_type *pointer; + typedef value_type const *const_pointer; + + //! Propagation traits + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + using is_always_equal = std::false_type; + + //! 
Rebind + template + struct rebind { + using other = arena_stl_allocator; + }; + +public: + arena_stl_allocator(ArenaAllocator &arena_allocator_p) noexcept // NOLINT: allow implicit conversion + : arena_allocator(arena_allocator_p) { + } + template + arena_stl_allocator(const arena_stl_allocator &other) noexcept // NOLINT: allow implicit conversion + : arena_allocator(other.GetAllocator()) { + } + +public: + pointer allocate(size_type n) { // NOLINT: match stl case + arena_allocator.get().AlignNext(); + return reinterpret_cast(arena_allocator.get().Allocate(n * sizeof(T))); + } + + void deallocate(pointer p, size_type n) noexcept { // NOLINT: match stl case + } + + template + void construct(U *p, Args &&...args) { // NOLINT: match stl case + ::new (p) U(std::forward(args)...); + } + + template + void destroy(U *p) noexcept { // NOLINT: match stl case + p->~U(); + } + + pointer address(reference x) const { // NOLINT: match stl case + return &x; + } + + const_pointer address(const_reference x) const { // NOLINT: match stl case + return &x; + } + + ArenaAllocator &GetAllocator() const { + return arena_allocator.get(); + } + +public: + bool operator==(const arena_stl_allocator &other) const noexcept { + return RefersToSameObject(arena_allocator, other.arena_allocator); + } + bool operator!=(const arena_stl_allocator &other) const noexcept { + return !(*this == other); + } + +private: + //! Need to use std::reference_wrapper because "reference" is already a typedef + std::reference_wrapper arena_allocator; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp index e28c002ee..60fc7659a 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp @@ -71,9 +71,10 @@ struct ArrowUUIDBlobConverter { template static TGT Operation(hugeint_t input) { // Turn into big-end - auto upper = BSwap(input.lower); + auto upper = BSwapIfLE(input.lower); // flip Upper MSD - auto lower = BSwap(static_cast(static_cast(input.upper) ^ (static_cast(1) << 63))); + auto lower = + BSwapIfLE(static_cast(static_cast(input.upper) ^ (static_cast(1) << 63))); return {static_cast(upper), static_cast(lower)}; } diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp index 811a410a5..3f736731b 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_query_result.hpp @@ -31,10 +31,6 @@ class ArrowQueryResult : public QueryResult { DUCKDB_API explicit ArrowQueryResult(ErrorData error); public: - //! Fetches a DataChunk from the query result. - //! This will consume the result (i.e. the result can only be scanned once with this function) - DUCKDB_API unique_ptr Fetch() override; - DUCKDB_API unique_ptr FetchRaw() override; //! 
Converts the QueryResult to a string DUCKDB_API string ToString() override; @@ -44,6 +40,9 @@ class ArrowQueryResult : public QueryResult { void SetArrowData(vector> arrays); idx_t BatchSize() const; +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: vector> arrays; idx_t batch_size; diff --git a/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp b/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp index 3bd89e67f..d659235d2 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/physical_arrow_collector.hpp @@ -47,7 +47,7 @@ class PhysicalArrowCollector : public PhysicalResultCollector { static PhysicalOperator &Create(ClientContext &context, PreparedStatementData &data, idx_t batch_size); SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; unique_ptr GetGlobalSinkState(ClientContext &context) const override; unique_ptr GetLocalSinkState(ExecutionContext &context) const override; SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, diff --git a/src/duckdb/src/include/duckdb/common/assert.hpp b/src/duckdb/src/include/duckdb/common/assert.hpp index dbf0744e7..4bf4b90e5 100644 --- a/src/duckdb/src/include/duckdb/common/assert.hpp +++ b/src/duckdb/src/include/duckdb/common/assert.hpp @@ -38,3 +38,6 @@ DUCKDB_API void DuckDBAssertInternal(bool condition, const char *condition_name, #define D_ASSERT_IS_ENABLED #endif + +//! Force assertion implementation, which always asserts whatever build type is used. 
+#define ALWAYS_ASSERT(condition) duckdb::DuckDBAssertInternal(bool(condition), #condition, __FILE__, __LINE__) diff --git a/src/duckdb/src/include/duckdb/common/bitpacking.hpp b/src/duckdb/src/include/duckdb/common/bitpacking.hpp index 06d12882a..618caf0e8 100644 --- a/src/duckdb/src/include/duckdb/common/bitpacking.hpp +++ b/src/duckdb/src/include/duckdb/common/bitpacking.hpp @@ -25,7 +25,6 @@ struct HugeIntPacker { }; class BitpackingPrimitives { - public: static constexpr const idx_t BITPACKING_ALGORITHM_GROUP_SIZE = 32; static constexpr const idx_t BITPACKING_HEADER_SIZE = sizeof(uint64_t); @@ -61,7 +60,6 @@ class BitpackingPrimitives { template inline static void UnPackBuffer(data_ptr_t dst, data_ptr_t src, idx_t count, bitpacking_width_t width, bool skip_sign_extension = false) { - for (idx_t i = 0; i < count; i += BITPACKING_ALGORITHM_GROUP_SIZE) { UnPackGroup(dst + i * sizeof(T), src + (i * width) / 8, width, skip_sign_extension); } diff --git a/src/duckdb/src/include/duckdb/common/box_renderer.hpp b/src/duckdb/src/include/duckdb/common/box_renderer.hpp index 914020249..38129f62e 100644 --- a/src/duckdb/src/include/duckdb/common/box_renderer.hpp +++ b/src/duckdb/src/include/duckdb/common/box_renderer.hpp @@ -20,7 +20,7 @@ class ColumnDataRowCollection; enum class ValueRenderAlignment { LEFT, MIDDLE, RIGHT }; enum class RenderMode : uint8_t { ROWS, COLUMNS }; -enum class ResultRenderType { LAYOUT, COLUMN_NAME, COLUMN_TYPE, VALUE, NULL_VALUE, FOOTER }; +enum class ResultRenderType { LAYOUT, COLUMN_NAME, COLUMN_TYPE, VALUE, NULL_VALUE, FOOTER, STRING_LITERAL }; class BaseResultRenderer { public: @@ -32,6 +32,9 @@ class BaseResultRenderer { virtual void RenderType(const string &text) = 0; virtual void RenderValue(const string &text, const LogicalType &type) = 0; virtual void RenderNull(const string &text, const LogicalType &type) = 0; + virtual void RenderStringLiteral(const string &text, const LogicalType &type) { + RenderValue(text, type); + } virtual void RenderFooter(const string &text) = 0; BaseResultRenderer &operator<<(char c); @@ -129,8 +132,6 @@ struct BoxRendererConfig { }; class BoxRenderer { - static const idx_t SPLIT_COLUMN; - public: explicit BoxRenderer(BoxRendererConfig config_p = BoxRendererConfig()); @@ -140,40 +141,12 @@ class BoxRenderer { BaseResultRenderer &ss); void Print(ClientContext &context, const vector &names, const ColumnDataCollection &op); + static string TryFormatLargeNumber(const string &numeric, char decimal_sep); + static string TruncateValue(const string &value, idx_t column_width, idx_t &pos, idx_t ¤t_render_width); + private: //! 
The configuration used for rendering BoxRendererConfig config; - -private: - void RenderValue(BaseResultRenderer &ss, const string &value, idx_t column_width, ResultRenderType render_mode, - ValueRenderAlignment alignment = ValueRenderAlignment::MIDDLE); - string RenderType(const LogicalType &type); - ValueRenderAlignment TypeAlignment(const LogicalType &type); - string GetRenderValue(BaseResultRenderer &ss, ColumnDataRowCollection &rows, idx_t c, idx_t r, - const LogicalType &type, ResultRenderType &render_mode); - list FetchRenderCollections(ClientContext &context, const ColumnDataCollection &result, - idx_t top_rows, idx_t bottom_rows); - list PivotCollections(ClientContext &context, list input, - vector &column_names, vector &result_types, - idx_t row_count); - vector ComputeRenderWidths(const vector &names, const vector &result_types, - list &collections, idx_t min_width, idx_t max_width, - vector &column_map, idx_t &total_length); - void RenderHeader(const vector &names, const vector &result_types, - const vector &column_map, const vector &widths, const vector &boundaries, - idx_t total_length, bool has_results, BaseResultRenderer &renderer); - void RenderValues(const list &collections, const vector &column_map, - const vector &widths, const vector &result_types, BaseResultRenderer &ss); - void RenderRowCount(string &row_count_str, string &readable_rows_str, string &shown_str, - const string &column_count_str, const vector &boundaries, bool has_hidden_rows, - bool has_hidden_columns, idx_t total_length, idx_t row_count, idx_t column_count, - idx_t minimum_row_length, BaseResultRenderer &ss); - - string FormatNumber(const string &input); - string ConvertRenderValue(const string &input, const LogicalType &type); - string ConvertRenderValue(const string &input); - //! Try to format a large number in a readable way (e.g. 
1234567 -> 1.23 million) - string TryFormatLargeNumber(const string &numeric); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/bswap.hpp b/src/duckdb/src/include/duckdb/common/bswap.hpp index fbcafb8f6..a1434da73 100644 --- a/src/duckdb/src/include/duckdb/common/bswap.hpp +++ b/src/duckdb/src/include/duckdb/common/bswap.hpp @@ -11,8 +11,23 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/numeric_utils.hpp" +#include + namespace duckdb { +#ifndef DUCKDB_IS_BIG_ENDIAN +#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define DUCKDB_IS_BIG_ENDIAN 1 +#else +#define DUCKDB_IS_BIG_ENDIAN 0 +#endif +#endif + +#if defined(__clang__) || defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) +#define BSWAP16(x) __builtin_bswap16(static_cast(x)) +#define BSWAP32(x) __builtin_bswap32(static_cast(x)) +#define BSWAP64(x) __builtin_bswap64(static_cast(x)) +#else #define BSWAP16(x) ((uint16_t)((((uint16_t)(x)&0xff00) >> 8) | (((uint16_t)(x)&0x00ff) << 8))) #define BSWAP32(x) \ @@ -24,6 +39,11 @@ namespace duckdb { (((uint64_t)(x)&0x0000ff0000000000ull) >> 24) | (((uint64_t)(x)&0x000000ff00000000ull) >> 8) | \ (((uint64_t)(x)&0x00000000ff000000ull) << 8) | (((uint64_t)(x)&0x0000000000ff0000ull) << 24) | \ (((uint64_t)(x)&0x000000000000ff00ull) << 40) | (((uint64_t)(x)&0x00000000000000ffull) << 56))) +#endif + +static inline int8_t BSwap(const int8_t &x) { + return x; +} static inline uint8_t BSwap(const uint8_t &x) { return x; @@ -33,10 +53,18 @@ static inline uint16_t BSwap(const uint16_t &x) { return BSWAP16(x); } +static inline int16_t BSwap(const int16_t &x) { + return static_cast(BSWAP16(x)); +} + static inline uint32_t BSwap(const uint32_t &x) { return BSWAP32(x); } +static inline int32_t BSwap(const int32_t &x) { + return static_cast(BSWAP32(x)); +} + static inline uint64_t BSwap(const uint64_t &x) { return BSWAP64(x); } @@ -45,4 +73,48 @@ static inline int64_t BSwap(const int64_t &x) { return static_cast(BSWAP64(x)); } +static inline uhugeint_t BSwap(const uhugeint_t &x) { + return uhugeint_t(BSWAP64(x.upper), BSWAP64(x.lower)); +} + +static inline hugeint_t BSwap(const hugeint_t &x) { + return hugeint_t(static_cast(BSWAP64(x.upper)), BSWAP64(x.lower)); +} + +static inline float BSwap(const float &x) { + uint32_t temp; + std::memcpy(&temp, &x, sizeof(temp)); + temp = BSWAP32(temp); + float result; + std::memcpy(&result, &temp, sizeof(result)); + return result; +} + +static inline double BSwap(const double &x) { + uint64_t temp; + std::memcpy(&temp, &x, sizeof(temp)); + temp = BSWAP64(temp); + double result; + std::memcpy(&result, &temp, sizeof(result)); + return result; +} + +template +static inline T BSwapIfLE(const T &x) { +#if DUCKDB_IS_BIG_ENDIAN + return x; +#else + return BSwap(x); +#endif +} + +template +static inline T BSwapIfBE(const T &x) { +#if DUCKDB_IS_BIG_ENDIAN + return BSwap(x); +#else + return x; +#endif +} + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/compatible_with_ipp.hpp b/src/duckdb/src/include/duckdb/common/compatible_with_ipp.hpp new file mode 100644 index 000000000..bbedb0ee8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/compatible_with_ipp.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/memory_safety.hpp" + +#include +#include + +namespace duckdb { + +// This implementation is taken from the llvm-project, at this commit hash: 
+// https://github.com/llvm/llvm-project/blob/08bb121835be432ac52372f92845950628ce9a4a/libcxx/include/__memory/shared_ptr.h#353 +// originally named '__compatible_with' + +#if _LIBCPP_STD_VER >= 17 +template +struct __bounded_convertible_to_unbounded : std::false_type {}; + +template +struct __bounded_convertible_to_unbounded<_Up[_Np], T> : std::is_same, _Up[]> {}; + +template +struct compatible_with_t : std::_Or, __bounded_convertible_to_unbounded> {}; +#else +template +struct compatible_with_t : std::is_convertible {}; // NOLINT: invalid case style +#endif // _LIBCPP_STD_VER >= 17 + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/csv_writer.hpp b/src/duckdb/src/include/duckdb/common/csv_writer.hpp index b2d0e066e..188d1de7f 100644 --- a/src/duckdb/src/include/duckdb/common/csv_writer.hpp +++ b/src/duckdb/src/include/duckdb/common/csv_writer.hpp @@ -90,9 +90,6 @@ class CSVWriter { //! Closes the writer, optionally writes a postfix void Close(); - unique_ptr InitializeLocalWriteState(ClientContext &context, idx_t flush_size); - unique_ptr InitializeLocalWriteState(DatabaseInstance &db, idx_t flush_size); - vector> string_casts; idx_t BytesWritten(); diff --git a/src/duckdb/src/include/duckdb/common/deque.hpp b/src/duckdb/src/include/duckdb/common/deque.hpp index f5c8ba990..6b5d38826 100644 --- a/src/duckdb/src/include/duckdb/common/deque.hpp +++ b/src/duckdb/src/include/duckdb/common/deque.hpp @@ -8,8 +8,115 @@ #pragma once +#include "duckdb/common/assert.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/memory_safety.hpp" #include namespace duckdb { -using std::deque; -} + +template +class deque : public std::deque> { // NOLINT: matching name of std +public: + using original = std::deque>; + using original::original; + using value_type = typename original::value_type; + using allocator_type = typename original::allocator_type; + using size_type = typename original::size_type; + using difference_type = typename original::difference_type; + using reference = typename original::reference; + using const_reference = typename original::const_reference; + using pointer = typename original::pointer; + using const_pointer = typename original::const_pointer; + using iterator = typename original::iterator; + using const_iterator = typename original::const_iterator; + using reverse_iterator = typename original::reverse_iterator; + using const_reverse_iterator = typename original::const_reverse_iterator; + +private: + static inline void AssertIndexInBounds(idx_t index, idx_t size) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(index >= size)) { + throw InternalException("Attempted to access index %ld within deque of size %ld", index, size); + } +#endif + } + +public: +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif + inline void + clear() noexcept { // NOLINT: hiding on purpose + original::clear(); + } + + // Because we create the other constructor, the implicitly created constructor + // gets deleted, so we have to be explicit + deque() = default; + deque(original &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + template + deque(deque &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + + template + inline typename original::reference get(typename original::size_type __n) { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED) { + 
AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + template + inline typename original::const_reference get(typename original::size_type __n) const { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED) { + AssertIndexInBounds(__n, original::size()); + } + return original::operator[](__n); + } + + typename original::reference operator[](typename original::size_type __n) { // NOLINT: hiding on purpose + return get(__n); + } + typename original::const_reference operator[](typename original::size_type __n) const { // NOLINT: hiding on purpose + return get(__n); + } + + typename original::reference front() { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty deque!"); + } + return get(0); + } + + typename original::const_reference front() const { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty deque!"); + } + return get(0); + } + + typename original::reference back() { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty deque!"); + } + return get(original::size() - 1); + } + + typename original::const_reference back() const { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty deque!"); + } + return get(original::size() - 1); + } +}; + +template +using unsafe_deque = deque; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enable_shared_from_this_ipp.hpp b/src/duckdb/src/include/duckdb/common/enable_shared_from_this_ipp.hpp index 85cdd2205..5c707082e 100644 --- a/src/duckdb/src/include/duckdb/common/enable_shared_from_this_ipp.hpp +++ b/src/duckdb/src/include/duckdb/common/enable_shared_from_this_ipp.hpp @@ -1,3 +1,8 @@ +#pragma once + +#include "duckdb/common/shared_ptr_ipp.hpp" +#include "duckdb/common/weak_ptr_ipp.hpp" + namespace duckdb { template diff --git a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp index 07f50b98c..35a439411 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp @@ -1,6 +1,7 @@ #pragma once #include "duckdb/common/helper.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" namespace duckdb { @@ -26,8 +27,32 @@ struct EncryptionNonce { unique_ptr nonce; }; -class EncryptionEngine { +class AdditionalAuthenticatedData { +public: + explicit AdditionalAuthenticatedData(Allocator &allocator) + : additional_authenticated_data(make_uniq(allocator, INITIAL_AAD_CAPACITY)) { + } + virtual ~AdditionalAuthenticatedData(); + +public: + template + void WriteData(const T &val) { + additional_authenticated_data->WriteData(reinterpret_cast(&val), sizeof(val)); + } + +public: + void WriteStringData(const std::string &val) const; + data_ptr_t data() const; + idx_t size() const; + +private: + static constexpr uint32_t INITIAL_AAD_CAPACITY = 32; +protected: + unique_ptr additional_authenticated_data; +}; + +class EncryptionEngine { public: EncryptionEngine(); ~EncryptionEngine(); diff --git a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp index 55c3aed75..d37423430 100644 --- 
a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp @@ -17,7 +17,6 @@ namespace duckdb { class EncryptionKey { - public: explicit EncryptionKey(data_ptr_t encryption_key); ~EncryptionKey(); @@ -42,7 +41,6 @@ class EncryptionKey { }; class EncryptionKeyManager : public ObjectCacheEntry { - public: static EncryptionKeyManager &GetInternal(ObjectCache &cache); static EncryptionKeyManager &Get(ClientContext &context); @@ -66,6 +64,8 @@ class EncryptionKeyManager : public ObjectCacheEntry { static void KeyDerivationFunctionSHA256(data_ptr_t user_key, idx_t user_key_size, data_ptr_t salt, data_ptr_t derived_key); static string Base64Decode(const string &key); + + //! Generate a (non-cryptographically secure) random key ID static string GenerateRandomKeyID(); public: @@ -74,6 +74,7 @@ class EncryptionKeyManager : public ObjectCacheEntry { static constexpr idx_t DERIVED_KEY_LENGTH = 32; private: + mutable mutex lock; std::unordered_map derived_keys; }; diff --git a/src/duckdb/src/include/duckdb/common/encryption_state.hpp b/src/duckdb/src/include/duckdb/common/encryption_state.hpp index 32c0597a9..4aece4a2f 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_state.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_state.hpp @@ -14,7 +14,6 @@ namespace duckdb { class EncryptionTypes { - public: enum CipherType : uint8_t { INVALID = 0, GCM = 1, CTR = 2, CBC = 3 }; enum KeyDerivationFunction : uint8_t { DEFAULT = 0, SHA256 = 1, PBKDF2 = 2 }; @@ -27,7 +26,6 @@ class EncryptionTypes { }; class EncryptionState { - public: DUCKDB_API explicit EncryptionState(EncryptionTypes::CipherType cipher_p, idx_t key_len); DUCKDB_API virtual ~EncryptionState(); @@ -47,7 +45,6 @@ class EncryptionState { }; class EncryptionUtil { - public: DUCKDB_API explicit EncryptionUtil() {}; @@ -59,6 +56,11 @@ class EncryptionUtil { virtual ~EncryptionUtil() { } + + //! 
Whether the EncryptionUtil supports encryption (some may only support decryption) + DUCKDB_API virtual bool SupportsEncryption() { + return true; + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index d07e93d02..f4e7a9dba 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -78,6 +78,10 @@ enum class ArrowTypeInfoType : uint8_t; enum class ArrowVariableSizeType : uint8_t; +enum class AsyncResultType : uint8_t; + +enum class AsyncResultsExecutionMode : uint8_t; + enum class BinderType : uint8_t; enum class BindingMode : uint8_t; @@ -88,12 +92,16 @@ enum class BlockIteratorStateType : int8_t; enum class BlockState : uint8_t; +enum class BufferedIndexReplay : uint8_t; + enum class CAPIResultSetType : uint8_t; enum class CSVState : uint8_t; enum class CTEMaterialize : uint8_t; +enum class CachingMode : uint8_t; + enum class CatalogLookupBehavior : uint8_t; enum class CatalogType : uint8_t; @@ -202,6 +210,8 @@ enum class FunctionStability : uint8_t; enum class GateStatus : uint8_t; +enum class GeometryType : uint8_t; + enum class HLLStorageType : uint8_t; enum class HTTPStatusCode : uint16_t; @@ -256,7 +266,9 @@ enum class MergeActionType : uint8_t; enum class MetaPipelineType : uint8_t; -enum class MetricsType : uint8_t; +enum class MetricGroup : uint8_t; + +enum class MetricType : uint8_t; enum class MultiFileColumnMappingMode : uint8_t; @@ -294,8 +306,6 @@ enum class ParseInfoType : uint8_t; enum class ParserExtensionResultType : uint8_t; -enum class PartitionSortStage : uint8_t; - enum class PartitionedColumnDataType : uint8_t; enum class PartitionedTupleDataType : uint8_t; @@ -304,6 +314,8 @@ enum class PendingExecutionResult : uint8_t; enum class PhysicalOperatorType : uint8_t; +enum class PhysicalTableScanExecutionStrategy : uint8_t; + enum class PhysicalType : uint8_t; enum class PragmaType : uint8_t; @@ -322,8 +334,14 @@ enum class QuantileSerializationType : uint8_t; enum class QueryNodeType : uint8_t; +enum class QueryResultMemoryType : uint8_t; + +enum class QueryResultOutputType : uint8_t; + enum class QueryResultType : uint8_t; +enum class RecoveryMode : uint8_t; + enum class RelationType : uint8_t; enum class RenderMode : uint8_t; @@ -430,6 +448,10 @@ enum class VariantChildLookupMode : uint8_t; enum class VariantLogicalType : uint8_t; +enum class VariantStatsShreddingState : uint8_t; + +enum class VariantValueType : uint8_t; + enum class VectorAuxiliaryDataType : uint8_t; enum class VectorBufferType : uint8_t; @@ -440,6 +462,8 @@ enum class VerificationType : uint8_t; enum class VerifyExistenceType : uint8_t; +enum class VertexType : uint8_t; + enum class WALType : uint8_t; enum class WindowAggregationMode : uint32_t; @@ -520,6 +544,12 @@ const char* EnumUtil::ToChars(ArrowTypeInfoType value); template<> const char* EnumUtil::ToChars(ArrowVariableSizeType value); +template<> +const char* EnumUtil::ToChars(AsyncResultType value); + +template<> +const char* EnumUtil::ToChars(AsyncResultsExecutionMode value); + template<> const char* EnumUtil::ToChars(BinderType value); @@ -535,6 +565,9 @@ const char* EnumUtil::ToChars(BlockIteratorStateType val template<> const char* EnumUtil::ToChars(BlockState value); +template<> +const char* EnumUtil::ToChars(BufferedIndexReplay value); + template<> const char* EnumUtil::ToChars(CAPIResultSetType value); @@ -544,6 +577,9 @@ const char* EnumUtil::ToChars(CSVState value); 
template<> const char* EnumUtil::ToChars(CTEMaterialize value); +template<> +const char* EnumUtil::ToChars(CachingMode value); + template<> const char* EnumUtil::ToChars(CatalogLookupBehavior value); @@ -706,6 +742,9 @@ const char* EnumUtil::ToChars(FunctionStability value); template<> const char* EnumUtil::ToChars(GateStatus value); +template<> +const char* EnumUtil::ToChars(GeometryType value); + template<> const char* EnumUtil::ToChars(HLLStorageType value); @@ -788,7 +827,10 @@ template<> const char* EnumUtil::ToChars(MetaPipelineType value); template<> -const char* EnumUtil::ToChars(MetricsType value); +const char* EnumUtil::ToChars(MetricGroup value); + +template<> +const char* EnumUtil::ToChars(MetricType value); template<> const char* EnumUtil::ToChars(MultiFileColumnMappingMode value); @@ -844,9 +886,6 @@ const char* EnumUtil::ToChars(ParseInfoType value); template<> const char* EnumUtil::ToChars(ParserExtensionResultType value); -template<> -const char* EnumUtil::ToChars(PartitionSortStage value); - template<> const char* EnumUtil::ToChars(PartitionedColumnDataType value); @@ -859,6 +898,9 @@ const char* EnumUtil::ToChars(PendingExecutionResult val template<> const char* EnumUtil::ToChars(PhysicalOperatorType value); +template<> +const char* EnumUtil::ToChars(PhysicalTableScanExecutionStrategy value); + template<> const char* EnumUtil::ToChars(PhysicalType value); @@ -886,9 +928,18 @@ const char* EnumUtil::ToChars(QuantileSerializationTy template<> const char* EnumUtil::ToChars(QueryNodeType value); +template<> +const char* EnumUtil::ToChars(QueryResultMemoryType value); + +template<> +const char* EnumUtil::ToChars(QueryResultOutputType value); + template<> const char* EnumUtil::ToChars(QueryResultType value); +template<> +const char* EnumUtil::ToChars(RecoveryMode value); + template<> const char* EnumUtil::ToChars(RelationType value); @@ -1048,6 +1099,12 @@ const char* EnumUtil::ToChars(VariantChildLookupMode val template<> const char* EnumUtil::ToChars(VariantLogicalType value); +template<> +const char* EnumUtil::ToChars(VariantStatsShreddingState value); + +template<> +const char* EnumUtil::ToChars(VariantValueType value); + template<> const char* EnumUtil::ToChars(VectorAuxiliaryDataType value); @@ -1063,6 +1120,9 @@ const char* EnumUtil::ToChars(VerificationType value); template<> const char* EnumUtil::ToChars(VerifyExistenceType value); +template<> +const char* EnumUtil::ToChars(VertexType value); + template<> const char* EnumUtil::ToChars(WALType value); @@ -1148,6 +1208,12 @@ ArrowTypeInfoType EnumUtil::FromString(const char *value); template<> ArrowVariableSizeType EnumUtil::FromString(const char *value); +template<> +AsyncResultType EnumUtil::FromString(const char *value); + +template<> +AsyncResultsExecutionMode EnumUtil::FromString(const char *value); + template<> BinderType EnumUtil::FromString(const char *value); @@ -1163,6 +1229,9 @@ BlockIteratorStateType EnumUtil::FromString(const char * template<> BlockState EnumUtil::FromString(const char *value); +template<> +BufferedIndexReplay EnumUtil::FromString(const char *value); + template<> CAPIResultSetType EnumUtil::FromString(const char *value); @@ -1172,6 +1241,9 @@ CSVState EnumUtil::FromString(const char *value); template<> CTEMaterialize EnumUtil::FromString(const char *value); +template<> +CachingMode EnumUtil::FromString(const char *value); + template<> CatalogLookupBehavior EnumUtil::FromString(const char *value); @@ -1334,6 +1406,9 @@ FunctionStability EnumUtil::FromString(const char *value); template<> 
GateStatus EnumUtil::FromString(const char *value); +template<> +GeometryType EnumUtil::FromString(const char *value); + template<> HLLStorageType EnumUtil::FromString(const char *value); @@ -1416,7 +1491,10 @@ template<> MetaPipelineType EnumUtil::FromString(const char *value); template<> -MetricsType EnumUtil::FromString(const char *value); +MetricGroup EnumUtil::FromString(const char *value); + +template<> +MetricType EnumUtil::FromString(const char *value); template<> MultiFileColumnMappingMode EnumUtil::FromString(const char *value); @@ -1472,9 +1550,6 @@ ParseInfoType EnumUtil::FromString(const char *value); template<> ParserExtensionResultType EnumUtil::FromString(const char *value); -template<> -PartitionSortStage EnumUtil::FromString(const char *value); - template<> PartitionedColumnDataType EnumUtil::FromString(const char *value); @@ -1487,6 +1562,9 @@ PendingExecutionResult EnumUtil::FromString(const char * template<> PhysicalOperatorType EnumUtil::FromString(const char *value); +template<> +PhysicalTableScanExecutionStrategy EnumUtil::FromString(const char *value); + template<> PhysicalType EnumUtil::FromString(const char *value); @@ -1514,9 +1592,18 @@ QuantileSerializationType EnumUtil::FromString(const template<> QueryNodeType EnumUtil::FromString(const char *value); +template<> +QueryResultMemoryType EnumUtil::FromString(const char *value); + +template<> +QueryResultOutputType EnumUtil::FromString(const char *value); + template<> QueryResultType EnumUtil::FromString(const char *value); +template<> +RecoveryMode EnumUtil::FromString(const char *value); + template<> RelationType EnumUtil::FromString(const char *value); @@ -1676,6 +1763,12 @@ VariantChildLookupMode EnumUtil::FromString(const char * template<> VariantLogicalType EnumUtil::FromString(const char *value); +template<> +VariantStatsShreddingState EnumUtil::FromString(const char *value); + +template<> +VariantValueType EnumUtil::FromString(const char *value); + template<> VectorAuxiliaryDataType EnumUtil::FromString(const char *value); @@ -1691,6 +1784,9 @@ VerificationType EnumUtil::FromString(const char *value); template<> VerifyExistenceType EnumUtil::FromString(const char *value); +template<> +VertexType EnumUtil::FromString(const char *value); + template<> WALType EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/active_transaction_state.hpp b/src/duckdb/src/include/duckdb/common/enums/active_transaction_state.hpp new file mode 100644 index 000000000..018b48054 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/active_transaction_state.hpp @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/active_transaction_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +enum class ActiveTransactionState { UNSET, OTHER_TRANSACTIONS, NO_OTHER_TRANSACTIONS }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp b/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp index 321fc25b6..5cedd908b 100644 --- a/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp @@ -16,7 +16,10 @@ enum class CheckpointAbort : uint8_t { NO_ABORT = 0, DEBUG_ABORT_BEFORE_TRUNCATE = 1, DEBUG_ABORT_BEFORE_HEADER = 2, - DEBUG_ABORT_AFTER_FREE_LIST_WRITE = 3 + 
DEBUG_ABORT_AFTER_FREE_LIST_WRITE = 3, + DEBUG_ABORT_BEFORE_WAL_FINISH = 4, + DEBUG_ABORT_BEFORE_MOVING_RECOVERY = 5, + DEBUG_ABORT_BEFORE_DELETING_CHECKPOINT_WAL = 6 }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp index 1dda5ee64..5198f7627 100644 --- a/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/compression_type.hpp @@ -36,8 +36,48 @@ enum class CompressionType : uint8_t { COMPRESSION_COUNT // This has to stay the last entry of the type! }; -bool CompressionTypeIsDeprecated(CompressionType compression_type, - optional_ptr storage_manager = nullptr); +struct CompressionAvailabilityResult { +private: + enum class UnavailableReason : uint8_t { + AVAILABLE, + //! Introduced later, not available to this version + NOT_AVAILABLE_YET, + //! Used to be available, but isn't anymore + DEPRECATED + }; + +public: + CompressionAvailabilityResult() = default; + static CompressionAvailabilityResult Deprecated() { + return CompressionAvailabilityResult(UnavailableReason::DEPRECATED); + } + static CompressionAvailabilityResult NotAvailableYet() { + return CompressionAvailabilityResult(UnavailableReason::NOT_AVAILABLE_YET); + } + +public: + bool IsAvailable() const { + return reason == UnavailableReason::AVAILABLE; + } + bool IsDeprecated() { + D_ASSERT(!IsAvailable()); + return reason == UnavailableReason::DEPRECATED; + } + bool IsNotAvailableYet() { + D_ASSERT(!IsAvailable()); + return reason == UnavailableReason::NOT_AVAILABLE_YET; + } + +private: + explicit CompressionAvailabilityResult(UnavailableReason reason) : reason(reason) { + } + +public: + UnavailableReason reason = UnavailableReason::AVAILABLE; +}; + +CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compression_type, + optional_ptr storage_manager = nullptr); vector ListCompressionTypes(void); CompressionType CompressionTypeFromString(const string &str); string CompressionTypeToString(CompressionType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/database_modification_type.hpp b/src/duckdb/src/include/duckdb/common/enums/database_modification_type.hpp new file mode 100644 index 000000000..8f108986d --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/database_modification_type.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/database_modification_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +struct DatabaseModificationType { +public: + static constexpr idx_t INSERT_DATA = 1ULL << 0ULL; + static constexpr idx_t DELETE_DATA = 1ULL << 1ULL; + static constexpr idx_t UPDATE_DATA = 1ULL << 2ULL; + static constexpr idx_t ALTER_TABLE = 1ULL << 3ULL; + static constexpr idx_t CREATE_CATALOG_ENTRY = 1ULL << 4ULL; + static constexpr idx_t DROP_CATALOG_ENTRY = 1ULL << 5ULL; + static constexpr idx_t SEQUENCE = 1ULL << 6ULL; + static constexpr idx_t CREATE_INDEX = 1ULL << 7ULL; + static constexpr idx_t INSERT_DATA_WITH_INDEX = 1ULL << 8ULL; + + constexpr DatabaseModificationType() : value(0) { + } + constexpr DatabaseModificationType(idx_t value) : value(value) { // NOLINT : allow implicit conversion + } + + inline constexpr DatabaseModificationType operator|(DatabaseModificationType b) const { + return
DatabaseModificationType(value | b.value); + } + inline DatabaseModificationType &operator|=(DatabaseModificationType b) { + value |= b.value; + return *this; + } + + bool InsertData() const { + return value & INSERT_DATA; + } + bool InsertDataWithIndex() const { + return value & INSERT_DATA_WITH_INDEX; + } + bool DeleteData() const { + return value & DELETE_DATA; + } + bool UpdateData() const { + return value & UPDATE_DATA; + } + bool AlterTable() const { + return value & ALTER_TABLE; + } + bool CreateCatalogEntry() const { + return value & CREATE_CATALOG_ENTRY; + } + bool DropCatalogEntry() const { + return value & DROP_CATALOG_ENTRY; + } + bool Sequence() const { + return value & SEQUENCE; + } + bool CreateIndex() const { + return value & CREATE_INDEX; + } + +private: + idx_t value; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp b/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp index 6635ca454..149e8c33f 100644 --- a/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/explain_format.hpp @@ -12,6 +12,6 @@ namespace duckdb { -enum class ExplainFormat : uint8_t { DEFAULT, TEXT, JSON, HTML, GRAPHVIZ, YAML }; +enum class ExplainFormat : uint8_t { DEFAULT, TEXT, JSON, HTML, GRAPHVIZ, YAML, MERMAID }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/index_removal_type.hpp b/src/duckdb/src/include/duckdb/common/enums/index_removal_type.hpp new file mode 100644 index 000000000..0591a2003 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/index_removal_type.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/index_removal_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +enum class IndexRemovalType { + //! Remove from main index, insert into deleted_rows_in_use + MAIN_INDEX, + //! Remove from main index only + MAIN_INDEX_ONLY, + //! Revert MAIN_INDEX, i.e. append to main index and remove from deleted_rows_in_use + REVERT_MAIN_INDEX, + //! Revert MAIN_INDEX_ONLY, i.e. append to main index + REVERT_MAIN_INDEX_ONLY, + //! 
Remove from deleted_rows_in_use + DELETED_ROWS_IN_USE +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/memory_tag.hpp b/src/duckdb/src/include/duckdb/common/enums/memory_tag.hpp index 744107e43..55dc0b62e 100644 --- a/src/duckdb/src/include/duckdb/common/enums/memory_tag.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/memory_tag.hpp @@ -27,8 +27,9 @@ enum class MemoryTag : uint8_t { EXTENSION = 11, TRANSACTION = 12, EXTERNAL_FILE_CACHE = 13, + WINDOW = 14 }; -static constexpr const idx_t MEMORY_TAG_COUNT = 14; +static constexpr const idx_t MEMORY_TAG_COUNT = 15; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp index 8fd2790ab..82208c895 100644 --- a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp @@ -1,101 +1,176 @@ -//------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// // DuckDB // +// duckdb/common/enums/metric_type.hpp // -// duckdb/common/enums/metrics_type.hpp -// // This file is automatically generated by scripts/generate_metric_enums.py // Do not edit this file manually, your changes will be overwritten -//------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/types/value.hpp" #include "duckdb/common/unordered_set.hpp" -#include "duckdb/common/unordered_map.hpp" #include "duckdb/common/constants.hpp" -#include "duckdb/common/enum_util.hpp" #include "duckdb/common/enums/optimizer_type.hpp" namespace duckdb { -enum class MetricsType : uint8_t { - QUERY_NAME, - BLOCKED_THREAD_TIME, - CPU_TIME, - EXTRA_INFO, - CUMULATIVE_CARDINALITY, - OPERATOR_TYPE, - OPERATOR_CARDINALITY, - CUMULATIVE_ROWS_SCANNED, - OPERATOR_ROWS_SCANNED, - OPERATOR_TIMING, - RESULT_SET_SIZE, - LATENCY, - ROWS_RETURNED, - OPERATOR_NAME, - SYSTEM_PEAK_BUFFER_MEMORY, - SYSTEM_PEAK_TEMP_DIR_SIZE, - TOTAL_BYTES_READ, - TOTAL_BYTES_WRITTEN, - ALL_OPTIMIZERS, - CUMULATIVE_OPTIMIZER_TIMING, - PLANNER, - PLANNER_BINDING, - PHYSICAL_PLANNER, - PHYSICAL_PLANNER_COLUMN_BINDING, - PHYSICAL_PLANNER_RESOLVE_TYPES, - PHYSICAL_PLANNER_CREATE_PLAN, - OPTIMIZER_EXPRESSION_REWRITER, - OPTIMIZER_FILTER_PULLUP, - OPTIMIZER_FILTER_PUSHDOWN, - OPTIMIZER_EMPTY_RESULT_PULLUP, - OPTIMIZER_CTE_FILTER_PUSHER, - OPTIMIZER_REGEX_RANGE, - OPTIMIZER_IN_CLAUSE, - OPTIMIZER_JOIN_ORDER, - OPTIMIZER_DELIMINATOR, - OPTIMIZER_UNNEST_REWRITER, - OPTIMIZER_UNUSED_COLUMNS, - OPTIMIZER_STATISTICS_PROPAGATION, - OPTIMIZER_COMMON_SUBEXPRESSIONS, - OPTIMIZER_COMMON_AGGREGATE, - OPTIMIZER_COLUMN_LIFETIME, - OPTIMIZER_BUILD_SIDE_PROBE_SIDE, - OPTIMIZER_LIMIT_PUSHDOWN, - OPTIMIZER_TOP_N, - OPTIMIZER_COMPRESSED_MATERIALIZATION, - OPTIMIZER_DUPLICATE_GROUPS, - OPTIMIZER_REORDER_FILTER, - OPTIMIZER_SAMPLING_PUSHDOWN, - OPTIMIZER_JOIN_FILTER_PUSHDOWN, - OPTIMIZER_EXTENSION, - OPTIMIZER_MATERIALIZED_CTE, - OPTIMIZER_SUM_REWRITER, - OPTIMIZER_LATE_MATERIALIZATION, - OPTIMIZER_CTE_INLINING, +enum class MetricGroup : uint8_t { + ALL, + CORE, + DEFAULT, + EXECUTION, + FILE, + OPERATOR, + OPTIMIZER, + PHASE_TIMING, + INVALID, +}; + +enum class MetricType : uint8_t { + // Core metrics + CPU_TIME, + CUMULATIVE_CARDINALITY, + CUMULATIVE_ROWS_SCANNED, + EXTRA_INFO, + LATENCY, + QUERY_NAME, + RESULT_SET_SIZE, + 
ROWS_RETURNED, + // Execution metrics + BLOCKED_THREAD_TIME, + SYSTEM_PEAK_BUFFER_MEMORY, + SYSTEM_PEAK_TEMP_DIR_SIZE, + TOTAL_MEMORY_ALLOCATED, + // File metrics + ATTACH_LOAD_STORAGE_LATENCY, + ATTACH_REPLAY_WAL_LATENCY, + CHECKPOINT_LATENCY, + COMMIT_LOCAL_STORAGE_LATENCY, + TOTAL_BYTES_READ, + TOTAL_BYTES_WRITTEN, + WAITING_TO_ATTACH_LATENCY, + WAL_REPLAY_ENTRY_COUNT, + WRITE_TO_WAL_LATENCY, + // Operator metrics + OPERATOR_CARDINALITY, + OPERATOR_NAME, + OPERATOR_ROWS_SCANNED, + OPERATOR_TIMING, + OPERATOR_TYPE, + // Optimizer metrics + OPTIMIZER_EXPRESSION_REWRITER, + OPTIMIZER_FILTER_PULLUP, + OPTIMIZER_FILTER_PUSHDOWN, + OPTIMIZER_EMPTY_RESULT_PULLUP, + OPTIMIZER_CTE_FILTER_PUSHER, + OPTIMIZER_REGEX_RANGE, + OPTIMIZER_IN_CLAUSE, + OPTIMIZER_JOIN_ORDER, + OPTIMIZER_DELIMINATOR, + OPTIMIZER_UNNEST_REWRITER, + OPTIMIZER_UNUSED_COLUMNS, + OPTIMIZER_STATISTICS_PROPAGATION, + OPTIMIZER_COMMON_SUBEXPRESSIONS, + OPTIMIZER_COMMON_AGGREGATE, + OPTIMIZER_COLUMN_LIFETIME, + OPTIMIZER_BUILD_SIDE_PROBE_SIDE, + OPTIMIZER_LIMIT_PUSHDOWN, + OPTIMIZER_ROW_GROUP_PRUNER, + OPTIMIZER_TOP_N, + OPTIMIZER_TOP_N_WINDOW_ELIMINATION, + OPTIMIZER_COMPRESSED_MATERIALIZATION, + OPTIMIZER_DUPLICATE_GROUPS, + OPTIMIZER_REORDER_FILTER, + OPTIMIZER_SAMPLING_PUSHDOWN, + OPTIMIZER_JOIN_FILTER_PUSHDOWN, + OPTIMIZER_EXTENSION, + OPTIMIZER_MATERIALIZED_CTE, + OPTIMIZER_SUM_REWRITER, + OPTIMIZER_LATE_MATERIALIZATION, + OPTIMIZER_CTE_INLINING, + OPTIMIZER_COMMON_SUBPLAN, + OPTIMIZER_JOIN_ELIMINATION, + // PhaseTiming metrics + ALL_OPTIMIZERS, + CUMULATIVE_OPTIMIZER_TIMING, + PHYSICAL_PLANNER, + PHYSICAL_PLANNER_COLUMN_BINDING, + PHYSICAL_PLANNER_CREATE_PLAN, + PHYSICAL_PLANNER_RESOLVE_TYPES, + PLANNER, + PLANNER_BINDING, }; -struct MetricsTypeHashFunction { - uint64_t operator()(const MetricsType &index) const { - return std::hash()(static_cast(index)); - } +struct MetricTypeHashFunction { + uint64_t operator()(const MetricType &index) const { + return std::hash()(static_cast(index)); + } }; -typedef unordered_set profiler_settings_t; -typedef unordered_map profiler_metrics_t; +typedef unordered_set profiler_settings_t; +typedef unordered_map profiler_metrics_t; class MetricsUtils { public: - static profiler_settings_t GetOptimizerMetrics(); - static profiler_settings_t GetPhaseTimingMetrics(); + static constexpr uint8_t START_CORE = static_cast(MetricType::CPU_TIME); + static constexpr uint8_t END_CORE = static_cast(MetricType::ROWS_RETURNED); - static MetricsType GetOptimizerMetricByType(OptimizerType type); - static OptimizerType GetOptimizerTypeByMetric(MetricsType type); + static constexpr uint8_t START_EXECUTION = static_cast(MetricType::BLOCKED_THREAD_TIME); + static constexpr uint8_t END_EXECUTION = static_cast(MetricType::TOTAL_MEMORY_ALLOCATED); - static bool IsOptimizerMetric(MetricsType type); - static bool IsPhaseTimingMetric(MetricsType type); - static bool IsQueryGlobalMetric(MetricsType type); -}; + static constexpr uint8_t START_FILE = static_cast(MetricType::ATTACH_LOAD_STORAGE_LATENCY); + static constexpr uint8_t END_FILE = static_cast(MetricType::WRITE_TO_WAL_LATENCY); + + static constexpr uint8_t START_OPERATOR = static_cast(MetricType::OPERATOR_CARDINALITY); + static constexpr uint8_t END_OPERATOR = static_cast(MetricType::OPERATOR_TYPE); + + static constexpr uint8_t START_OPTIMIZER = static_cast(MetricType::OPTIMIZER_EXPRESSION_REWRITER); + static constexpr uint8_t END_OPTIMIZER = static_cast(MetricType::OPTIMIZER_JOIN_ELIMINATION); + + static constexpr uint8_t START_PHASE_TIMING = 
static_cast(MetricType::ALL_OPTIMIZERS); + static constexpr uint8_t END_PHASE_TIMING = static_cast(MetricType::PLANNER_BINDING); + +public: + // All metrics + static profiler_settings_t GetAllMetrics(); + static profiler_settings_t GetMetricsByGroupType(MetricGroup type); + + // Core metrics + static profiler_settings_t GetCoreMetrics(); + static bool IsCoreMetric(MetricType type); + + // Default metrics + static profiler_settings_t GetDefaultMetrics(); + static bool IsDefaultMetric(MetricType type); + + // Execution metrics + static profiler_settings_t GetExecutionMetrics(); + static bool IsExecutionMetric(MetricType type); + + // File metrics + static profiler_settings_t GetFileMetrics(); + static bool IsFileMetric(MetricType type); + + // Operator metrics + static profiler_settings_t GetOperatorMetrics(); + static bool IsOperatorMetric(MetricType type); + + // Optimizer metrics + static profiler_settings_t GetOptimizerMetrics(); + static bool IsOptimizerMetric(MetricType type); + static MetricType GetOptimizerMetricByType(OptimizerType type); + static OptimizerType GetOptimizerTypeByMetric(MetricType type); + + // PhaseTiming metrics + static profiler_settings_t GetPhaseTimingMetrics(); + static bool IsPhaseTimingMetric(MetricType type); + + // RootScope metrics + static profiler_settings_t GetRootScopeMetrics(); + static bool IsRootScopeMetric(MetricType type); +}; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp index c539a5636..5a161dcbe 100644 --- a/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/operator_result_type.hpp @@ -42,6 +42,20 @@ enum class OperatorFinalResultType : uint8_t { FINISHED, BLOCKED }; //! BLOCKED means the source is currently blocked, e.g. by some async I/O enum class SourceResultType : uint8_t { HAVE_MORE_OUTPUT, FINISHED, BLOCKED }; +//! AsyncResultType is used to indicate the result of a AsyncResult, in the context of a wider operation being executed +enum class AsyncResultType : uint8_t { + INVALID, // current result is in an invalid state (eg: it's in the process of being initialized) + IMPLICIT, // current result depends on external context (eg: in the context of TableFunctions, either FINISHED or + // HAVE_MORE_OUTPUT depending on output_chunk.size()) + HAVE_MORE_OUTPUT, // current result is not completed, finished (eg: in the context of TableFunctions, function + // accept more iterations and might produce further results) + FINISHED, // current result is completed, no subsequent calls on the same state should be attempted + BLOCKED // current result is blocked, no subsequent calls on the same state should be attempted (eg: in the context + // of AsyncResult, BLOCKED will be associated with a vector of AsyncTasks to be scheduled) +}; + +bool ExtractSourceResultType(AsyncResultType in, SourceResultType &out); + //! The SinkResultType is used to indicate the result of data flowing into a sink //! There are three possible results: //! 
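// --- Illustrative sketch (not part of the patch; the helper name below is hypothetical) ---
// With the metrics above laid out in contiguous per-group blocks, a group
// membership test such as MetricsUtils::IsOptimizerMetric() can reduce to a
// range check on the underlying enum value via the START_/END_ constants.
// The real implementation lives in the corresponding .cpp file; this only
// shows the pattern the new header layout enables.
static bool IsInOptimizerRangeSketch(duckdb::MetricType type) {
	auto value = static_cast<uint8_t>(type);
	return value >= duckdb::MetricsUtils::START_OPTIMIZER && value <= duckdb::MetricsUtils::END_OPTIMIZER;
}
// -------------------------------------------------------------------------------------------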
NEED_MORE_INPUT means the sink needs more input diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index b57823028..7f6864aac 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -32,7 +32,9 @@ enum class OptimizerType : uint32_t { COLUMN_LIFETIME, BUILD_SIDE_PROBE_SIDE, LIMIT_PUSHDOWN, + ROW_GROUP_PRUNER, TOP_N, + TOP_N_WINDOW_ELIMINATION, COMPRESSED_MATERIALIZATION, DUPLICATE_GROUPS, REORDER_FILTER, @@ -42,7 +44,9 @@ enum class OptimizerType : uint32_t { MATERIALIZED_CTE, SUM_REWRITER, LATE_MATERIALIZATION, - CTE_INLINING + CTE_INLINING, + COMMON_SUBPLAN, + JOIN_ELIMINATION }; string OptimizerTypeToString(OptimizerType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp index 1a416d546..9cd1206b9 100644 --- a/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/profiler_format.hpp @@ -12,6 +12,6 @@ namespace duckdb { -enum class ProfilerPrintFormat : uint8_t { QUERY_TREE, JSON, QUERY_TREE_OPTIMIZER, NO_OUTPUT, HTML, GRAPHVIZ }; +enum class ProfilerPrintFormat : uint8_t { QUERY_TREE, JSON, QUERY_TREE_OPTIMIZER, NO_OUTPUT, HTML, GRAPHVIZ, MERMAID }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp b/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp index 302b2f369..bca6af491 100644 --- a/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/relation_type.hpp @@ -43,7 +43,8 @@ enum class RelationType : uint8_t { VIEW_RELATION, QUERY_RELATION, DELIM_JOIN_RELATION, - DELIM_GET_RELATION + DELIM_GET_RELATION, + EXTENSION_RELATION = 255 }; string RelationTypeToString(RelationType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp index b6c9d08ae..fe7d6d960 100644 --- a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp @@ -11,6 +11,8 @@ #include "duckdb/common/constants.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/query_parameters.hpp" +#include "duckdb/common/enums/database_modification_type.hpp" namespace duckdb { @@ -67,8 +69,9 @@ class ClientContext; //! A struct containing various properties of a SQL statement struct StatementProperties { StatementProperties() - : requires_valid_transaction(true), allow_stream_result(false), bound_all_parameters(true), - return_type(StatementReturnType::QUERY_RESULT), parameter_count(0), always_require_rebind(false) { + : requires_valid_transaction(true), output_type(QueryResultOutputType::FORCE_MATERIALIZED), + bound_all_parameters(true), return_type(StatementReturnType::QUERY_RESULT), parameter_count(0), + always_require_rebind(false) { } struct CatalogIdentity { @@ -84,15 +87,20 @@ struct StatementProperties { } }; + struct ModificationInfo { + CatalogIdentity identity; + DatabaseModificationType modifications; + }; + //! The set of databases this statement will read from unordered_map read_databases; //! The set of databases this statement will modify - unordered_map modified_databases; + unordered_map modified_databases; //! 
Whether or not the statement requires a valid transaction. Almost all statements require this, with the //! exception of ROLLBACK bool requires_valid_transaction; //! Whether or not the result can be streamed to the client - bool allow_stream_result; + QueryResultOutputType output_type; //! Whether or not all parameters have successfully had their types determined bool bound_all_parameters; //! What type of data the statement returns @@ -107,8 +115,7 @@ struct StatementProperties { } void RegisterDBRead(Catalog &catalog, ClientContext &context); - - void RegisterDBModify(Catalog &catalog, ClientContext &context); + void RegisterDBModify(Catalog &catalog, ClientContext &context, DatabaseModificationType modification); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception.hpp b/src/duckdb/src/include/duckdb/common/exception.hpp index 480dd2385..17d430201 100644 --- a/src/duckdb/src/include/duckdb/common/exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception.hpp @@ -94,15 +94,16 @@ enum class ExceptionType : uint8_t { class Exception : public std::runtime_error { public: DUCKDB_API Exception(ExceptionType exception_type, const string &message); - DUCKDB_API Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API Exception(const unordered_map &extra_info, ExceptionType exception_type, + const string &message); public: DUCKDB_API static string ExceptionTypeToString(ExceptionType type); DUCKDB_API static ExceptionType StringToExceptionType(const string &type); template - static string ConstructMessage(const string &msg, ARGS... params) { + static string ConstructMessage(const string &msg, ARGS const &...params) { const std::size_t num_args = sizeof...(ARGS); if (num_args == 0) { return msg; @@ -122,8 +123,9 @@ class Exception : public std::runtime_error { //! Whether this exception type can occur during execution of a query DUCKDB_API static bool IsExecutionError(ExceptionType type); DUCKDB_API static string ToJSON(ExceptionType type, const string &message); - DUCKDB_API static string ToJSON(ExceptionType type, const string &message, - const unordered_map &extra_info); + + DUCKDB_API static string ToJSON(const unordered_map &extra_info, ExceptionType type, + const string &message); DUCKDB_API static bool InvalidatesTransaction(ExceptionType exception_type); DUCKDB_API static bool InvalidatesDatabase(ExceptionType exception_type); @@ -131,8 +133,8 @@ class Exception : public std::runtime_error { DUCKDB_API static string ConstructMessageRecursive(const string &msg, std::vector &values); template - static string ConstructMessageRecursive(const string &msg, std::vector &values, T param, - ARGS... params) { + static string ConstructMessageRecursive(const string &msg, std::vector &values, + const T ¶m, ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); return ConstructMessageRecursive(msg, values, params...); } @@ -155,8 +157,8 @@ class ConnectionException : public Exception { DUCKDB_API explicit ConnectionException(const string &msg); template - explicit ConnectionException(const string &msg, ARGS... 
params) - : ConnectionException(ConstructMessage(msg, params...)) { + explicit ConnectionException(const string &msg, ARGS &&...params) + : ConnectionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -165,8 +167,8 @@ class PermissionException : public Exception { DUCKDB_API explicit PermissionException(const string &msg); template - explicit PermissionException(const string &msg, ARGS... params) - : PermissionException(ConstructMessage(msg, params...)) { + explicit PermissionException(const string &msg, ARGS &&...params) + : PermissionException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -175,8 +177,8 @@ class OutOfRangeException : public Exception { DUCKDB_API explicit OutOfRangeException(const string &msg); template - explicit OutOfRangeException(const string &msg, ARGS... params) - : OutOfRangeException(ConstructMessage(msg, params...)) { + explicit OutOfRangeException(const string &msg, ARGS &&...params) + : OutOfRangeException(ConstructMessage(msg, std::forward(params)...)) { } DUCKDB_API OutOfRangeException(const int64_t value, const PhysicalType orig_type, const PhysicalType new_type); DUCKDB_API OutOfRangeException(const hugeint_t value, const PhysicalType orig_type, const PhysicalType new_type); @@ -189,8 +191,8 @@ class OutOfMemoryException : public Exception { DUCKDB_API explicit OutOfMemoryException(const string &msg); template - explicit OutOfMemoryException(const string &msg, ARGS... params) - : OutOfMemoryException(ConstructMessage(msg, params...)) { + explicit OutOfMemoryException(const string &msg, ARGS &&...params) + : OutOfMemoryException(ConstructMessage(msg, std::forward(params)...)) { } private: @@ -202,7 +204,8 @@ class SyntaxException : public Exception { DUCKDB_API explicit SyntaxException(const string &msg); template - explicit SyntaxException(const string &msg, ARGS... params) : SyntaxException(ConstructMessage(msg, params...)) { + explicit SyntaxException(const string &msg, ARGS &&...params) + : SyntaxException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -211,8 +214,8 @@ class ConstraintException : public Exception { DUCKDB_API explicit ConstraintException(const string &msg); template - explicit ConstraintException(const string &msg, ARGS... params) - : ConstraintException(ConstructMessage(msg, params...)) { + explicit ConstraintException(const string &msg, ARGS &&...params) + : ConstraintException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -221,25 +224,27 @@ class DependencyException : public Exception { DUCKDB_API explicit DependencyException(const string &msg); template - explicit DependencyException(const string &msg, ARGS... params) - : DependencyException(ConstructMessage(msg, params...)) { + explicit DependencyException(const string &msg, ARGS &&...params) + : DependencyException(ConstructMessage(msg, std::forward(params)...)) { } }; class IOException : public Exception { public: DUCKDB_API explicit IOException(const string &msg); - DUCKDB_API explicit IOException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit IOException(const unordered_map &extra_info, const string &msg); explicit IOException(ExceptionType exception_type, const string &msg) : Exception(exception_type, msg) { } template - explicit IOException(const string &msg, ARGS... 
params) : IOException(ConstructMessage(msg, params...)) { + explicit IOException(const string &msg, ARGS &&...params) + : IOException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit IOException(const string &msg, const unordered_map &extra_info, ARGS... params) - : IOException(ConstructMessage(msg, params...), extra_info) { + explicit IOException(const unordered_map &extra_info, const string &msg, ARGS &&...params) + : IOException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -248,18 +253,24 @@ class MissingExtensionException : public Exception { DUCKDB_API explicit MissingExtensionException(const string &msg); template - explicit MissingExtensionException(const string &msg, ARGS... params) - : MissingExtensionException(ConstructMessage(msg, params...)) { + explicit MissingExtensionException(const string &msg, ARGS &&...params) + : MissingExtensionException(ConstructMessage(msg, std::forward(params)...)) { } }; class NotImplementedException : public Exception { public: DUCKDB_API explicit NotImplementedException(const string &msg); + explicit NotImplementedException(const unordered_map &extra_info, const string &msg); template - explicit NotImplementedException(const string &msg, ARGS... params) - : NotImplementedException(ConstructMessage(msg, params...)) { + explicit NotImplementedException(const string &msg, ARGS &&...params) + : NotImplementedException(ConstructMessage(msg, std::forward(params)...)) { + } + template + explicit NotImplementedException(const unordered_map &extra_info, const string &msg, + ARGS &&...params) + : NotImplementedException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; @@ -273,8 +284,8 @@ class SerializationException : public Exception { DUCKDB_API explicit SerializationException(const string &msg); template - explicit SerializationException(const string &msg, ARGS... params) - : SerializationException(ConstructMessage(msg, params...)) { + explicit SerializationException(const string &msg, ARGS &&...params) + : SerializationException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -283,8 +294,8 @@ class SequenceException : public Exception { DUCKDB_API explicit SequenceException(const string &msg); template - explicit SequenceException(const string &msg, ARGS... params) - : SequenceException(ConstructMessage(msg, params...)) { + explicit SequenceException(const string &msg, ARGS &&...params) + : SequenceException(ConstructMessage(msg, std::forward(params)...)) { } }; @@ -298,39 +309,48 @@ class FatalException : public Exception { explicit FatalException(const string &msg) : FatalException(ExceptionType::FATAL, msg) { } template - explicit FatalException(const string &msg, ARGS... params) : FatalException(ConstructMessage(msg, params...)) { + explicit FatalException(const string &msg, ARGS &&...params) + : FatalException(ConstructMessage(msg, std::forward(params)...)) { } protected: DUCKDB_API explicit FatalException(ExceptionType type, const string &msg); template - explicit FatalException(ExceptionType type, const string &msg, ARGS... 
params) - : FatalException(type, ConstructMessage(msg, params...)) { + explicit FatalException(ExceptionType type, const string &msg, ARGS &&...params) + : FatalException(type, ConstructMessage(msg, std::forward(params)...)) { } }; class InternalException : public Exception { public: DUCKDB_API explicit InternalException(const string &msg); + InternalException(const unordered_map &extra_info, const string &msg); template - explicit InternalException(const string &msg, ARGS... params) - : InternalException(ConstructMessage(msg, params...)) { + explicit InternalException(const string &msg, ARGS &&...params) + : InternalException(ConstructMessage(msg, std::forward(params)...)) { + } + + template + explicit InternalException(const unordered_map &extra_info, const string &msg, ARGS &&...params) + : InternalException(extra_info, ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidInputException : public Exception { public: DUCKDB_API explicit InvalidInputException(const string &msg); - DUCKDB_API explicit InvalidInputException(const string &msg, const unordered_map &extra_info); + DUCKDB_API explicit InvalidInputException(const unordered_map &extra_info, const string &msg); template - explicit InvalidInputException(const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...)) { + explicit InvalidInputException(const string &msg, ARGS &&...params) + : InvalidInputException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit InvalidInputException(const Expression &expr, const string &msg, ARGS... params) - : InvalidInputException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidInputException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidInputException(Exception::InitializeExtraInfo(expr), + ConstructMessage(msg, std::forward(params)...)) { } }; @@ -339,24 +359,26 @@ class ExecutorException : public Exception { DUCKDB_API explicit ExecutorException(const string &msg); template - explicit ExecutorException(const string &msg, ARGS... params) - : ExecutorException(ConstructMessage(msg, params...)) { + explicit ExecutorException(const string &msg, ARGS &&...params) + : ExecutorException(ConstructMessage(msg, std::forward(params)...)) { } }; class InvalidConfigurationException : public Exception { public: DUCKDB_API explicit InvalidConfigurationException(const string &msg); - DUCKDB_API explicit InvalidConfigurationException(const string &msg, - const unordered_map &extra_info); + + DUCKDB_API explicit InvalidConfigurationException(const unordered_map &extra_info, + const string &msg); template - explicit InvalidConfigurationException(const string &msg, ARGS... params) - : InvalidConfigurationException(ConstructMessage(msg, params...)) { + explicit InvalidConfigurationException(const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS... 
params) - : InvalidConfigurationException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit InvalidConfigurationException(const Expression &expr, const string &msg, ARGS &&...params) + : InvalidConfigurationException(ConstructMessage(msg, std::forward(params)...), + Exception::InitializeExtraInfo(expr)) { } }; @@ -381,8 +403,8 @@ class ParameterNotAllowedException : public Exception { DUCKDB_API explicit ParameterNotAllowedException(const string &msg); template - explicit ParameterNotAllowedException(const string &msg, ARGS... params) - : ParameterNotAllowedException(ConstructMessage(msg, params...)) { + explicit ParameterNotAllowedException(const string &msg, ARGS &&...params) + : ParameterNotAllowedException(ConstructMessage(msg, std::forward(params)...)) { } }; diff --git a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp index 2590cb094..5aec7e296 100644 --- a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp @@ -9,37 +9,46 @@ #pragma once #include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" #include "duckdb/parser/query_error_context.hpp" namespace duckdb { class BinderException : public Exception { public: - DUCKDB_API explicit BinderException(const string &msg, const unordered_map &extra_info); DUCKDB_API explicit BinderException(const string &msg); + DUCKDB_API explicit BinderException(const unordered_map &extra_info, const string &msg); + template - explicit BinderException(const string &msg, ARGS... params) : BinderException(ConstructMessage(msg, params...)) { + explicit BinderException(const string &msg, ARGS &&...params) + : BinderException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const TableRef &ref, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(ref)) { + explicit BinderException(const TableRef &ref, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(ref), ConstructMessage(msg, std::forward(params)...)) { } template - explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(const Expression &expr, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit BinderException(const Expression &expr, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS... 
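// --- Self-contained sketch (not part of the patch; SketchException is a hypothetical stand-in) ---
// The signature change applied to every exception constructor above: the pack
// is taken by forwarding reference (ARGS &&...) and passed on with std::forward,
// so format arguments are no longer copied by value at each delegation step.
#include <string>
#include <utility>

struct SketchException {
	explicit SketchException(const std::string &msg) : message(msg) {
	}
	template <typename... ARGS>
	explicit SketchException(const std::string &msg, ARGS &&...params)
	    : SketchException(Construct(msg, std::forward<ARGS>(params)...)) {
	}
	std::string message;

private:
	// Stand-in for Exception::ConstructMessage; the formatting itself is unchanged by the patch.
	template <typename... ARGS>
	static std::string Construct(const std::string &msg, ARGS &&...) {
		return msg;
	}
};
// ---------------------------------------------------------------------------------------------------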
params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit BinderException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } + template - explicit BinderException(optional_idx error_location, const string &msg, ARGS... params) - : BinderException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit BinderException(optional_idx error_location, const string &msg, ARGS &&...params) + : BinderException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } static BinderException ColumnNotFound(const string &name, const vector &similar_bindings, diff --git a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp index 498fafd19..aadbc9f83 100644 --- a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp @@ -9,6 +9,8 @@ #pragma once #include "duckdb/common/exception.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/string.hpp" #include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/parser/query_error_context.hpp" #include "duckdb/common/unordered_map.hpp" @@ -19,14 +21,18 @@ struct EntryLookupInfo; class CatalogException : public Exception { public: DUCKDB_API explicit CatalogException(const string &msg); - DUCKDB_API explicit CatalogException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit CatalogException(const unordered_map &extra_info, const string &msg); template - explicit CatalogException(const string &msg, ARGS... params) : CatalogException(ConstructMessage(msg, params...)) { + explicit CatalogException(const string &msg, ARGS &&...params) + : CatalogException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS... params) - : CatalogException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_context)) { + explicit CatalogException(QueryErrorContext error_context, const string &msg, ARGS &&...params) + : CatalogException(Exception::InitializeExtraInfo(error_context), + ConstructMessage(msg, std::forward(params)...)) { } static CatalogException MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion); diff --git a/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp index 5330f46e6..9252d0790 100644 --- a/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/conversion_exception.hpp @@ -12,22 +12,24 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { - class ConversionException : public Exception { public: DUCKDB_API explicit ConversionException(const string &msg); + DUCKDB_API explicit ConversionException(optional_idx error_location, const string &msg); + DUCKDB_API ConversionException(const PhysicalType orig_type, const PhysicalType new_type); + DUCKDB_API ConversionException(const LogicalType &orig_type, const LogicalType &new_type); template - explicit ConversionException(const string &msg, ARGS... 
params) - : ConversionException(ConstructMessage(msg, params...)) { + explicit ConversionException(const string &msg, ARGS &&...params) + : ConversionException(ConstructMessage(msg, std::forward(params)...)) { } + template - explicit ConversionException(optional_idx error_location, const string &msg, ARGS... params) - : ConversionException(error_location, ConstructMessage(msg, params...)) { + explicit ConversionException(optional_idx error_location, const string &msg, ARGS &&...params) + : ConversionException(error_location, ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp index aff00d23d..b0d0e9c2d 100644 --- a/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/http_exception.hpp @@ -24,9 +24,9 @@ class HTTPException : public Exception { } template ::status = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.status), response.body, response.headers, response.reason, msg, - params...) { + std::forward(params)...) { } template @@ -35,16 +35,16 @@ class HTTPException : public Exception { }; template ::code = 0, typename... ARGS> - explicit HTTPException(RESPONSE &response, const string &msg, ARGS... params) + explicit HTTPException(RESPONSE &response, const string &msg, ARGS &&...params) : HTTPException(static_cast(response.code), response.body, response.headers, response.error, msg, - params...) { + std::forward(params)...) { } template explicit HTTPException(int status_code, const string &response_body, const HEADERS &headers, const string &reason, - const string &msg, ARGS... params) - : Exception(ExceptionType::HTTP, ConstructMessage(msg, params...), - HTTPExtraInfo(status_code, response_body, headers, reason)) { + const string &msg, ARGS &&...params) + : Exception(HTTPExtraInfo(status_code, response_body, headers, reason), ExceptionType::HTTP, + ConstructMessage(msg, std::forward(params)...)) { } template diff --git a/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp index 363a34457..26ce6c585 100644 --- a/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/parser_exception.hpp @@ -17,18 +17,21 @@ namespace duckdb { class ParserException : public Exception { public: DUCKDB_API explicit ParserException(const string &msg); - DUCKDB_API explicit ParserException(const string &msg, const unordered_map &extra_info); + + DUCKDB_API explicit ParserException(const unordered_map &extra_info, const string &msg); template - explicit ParserException(const string &msg, ARGS... params) : ParserException(ConstructMessage(msg, params...)) { + explicit ParserException(const string &msg, ARGS &&...params) + : ParserException(ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(optional_idx error_location, const string &msg, ARGS... 
params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(error_location)) { + explicit ParserException(optional_idx error_location, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(error_location), + ConstructMessage(msg, std::forward(params)...)) { } template - explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS... params) - : ParserException(ConstructMessage(msg, params...), Exception::InitializeExtraInfo(expr)) { + explicit ParserException(const ParsedExpression &expr, const string &msg, ARGS &&...params) + : ParserException(Exception::InitializeExtraInfo(expr), ConstructMessage(msg, std::forward(params)...)) { } static ParserException SyntaxError(const string &query, const string &error_message, optional_idx error_location); diff --git a/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp index f0164df69..5ca0be62b 100644 --- a/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/transaction_exception.hpp @@ -11,15 +11,13 @@ #include "duckdb/common/exception.hpp" namespace duckdb { - class TransactionException : public Exception { public: DUCKDB_API explicit TransactionException(const string &msg); template - explicit TransactionException(const string &msg, ARGS... params) - : TransactionException(ConstructMessage(msg, params...)) { + explicit TransactionException(const string &msg, ARGS &&...params) + : TransactionException(ConstructMessage(msg, std::forward(params)...)) { } }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp index 3693db54c..7beeead6e 100644 --- a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp +++ b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp @@ -49,13 +49,13 @@ enum class ExceptionFormatValueType : uint8_t { }; struct ExceptionFormatValue { - DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT - DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT - DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT - DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(String str_val); // NOLINT - DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT - DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT + DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT + DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT + DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(const String &str_val); // NOLINT + DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT + DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT ExceptionFormatValueType type; @@ -65,37 +65,37 @@ struct ExceptionFormatValue { public: template - static ExceptionFormatValue CreateFormatValue(T value) { + static ExceptionFormatValue CreateFormatValue(const T &value) { return int64_t(value); } static string Format(const string &msg, std::vector &values); }; template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const PhysicalType &value); template <> -DUCKDB_API 
ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLString value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(SQLIdentifier value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLIdentifier &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(LogicalType value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const LogicalType &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const float &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const double &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(String value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *const &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const idx_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const hugeint_t &value); template <> -DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value); +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const uhugeint_t &value); } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp index d5e35ee96..02348d69b 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -28,7 +28,8 @@ enum class ExtraTypeInfoType : uint8_t { ARRAY_TYPE_INFO = 9, ANY_TYPE_INFO = 10, INTEGER_LITERAL_TYPE_INFO = 11, - TEMPLATE_TYPE_INFO = 12 + TEMPLATE_TYPE_INFO = 12, + GEO_TYPE_INFO = 13 }; struct ExtraTypeInfo { @@ -261,7 +262,6 @@ struct IntegerLiteralTypeInfo : public ExtraTypeInfo { }; struct TemplateTypeInfo : public ExtraTypeInfo { - explicit TemplateTypeInfo(string name_p); // The name of the template, e.g. `T`, or `KEY_TYPE`. 
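// --- Illustrative note (not part of the patch; Describe is a hypothetical example) ---
// Because CreateFormatValue now takes its argument as const T &, the pointer
// specializations are spelled against the referenced pointer type, e.g.
// "const char *const &" for T = const char *. A minimal analogue of that rule:
#include <string>

template <typename T>
std::string Describe(const T &) {
	return "value";
}
template <>
std::string Describe(const char *const &) {
	return "c-string";
}
// ---------------------------------------------------------------------------------------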
Used to distinguish between different template types within @@ -278,4 +278,16 @@ struct TemplateTypeInfo : public ExtraTypeInfo { TemplateTypeInfo(); }; +struct GeoTypeInfo : public ExtraTypeInfo { +public: + GeoTypeInfo(); + + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + shared_ptr Copy() const override; + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/file_buffer.hpp b/src/duckdb/src/include/duckdb/common/file_buffer.hpp index f330854c2..e0fcfcdec 100644 --- a/src/duckdb/src/include/duckdb/common/file_buffer.hpp +++ b/src/duckdb/src/include/duckdb/common/file_buffer.hpp @@ -13,7 +13,7 @@ namespace duckdb { -class Allocator; +class BlockAllocator; class BlockManager; class QueryContext; @@ -30,13 +30,13 @@ class FileBuffer { //! (typically 8 bytes). On return, this->AllocSize() >= this->size >= user_size. //! Our allocation size will always be page-aligned, which is necessary to support //! DIRECT_IO - FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size, idx_t block_header_size); - FileBuffer(Allocator &allocator, FileBufferType type, BlockManager &block_manager); + FileBuffer(BlockAllocator &allocator, FileBufferType type, uint64_t user_size, idx_t block_header_size); + FileBuffer(BlockAllocator &allocator, FileBufferType type, BlockManager &block_manager); FileBuffer(FileBuffer &source, FileBufferType type, idx_t block_header_size); virtual ~FileBuffer(); - Allocator &allocator; + BlockAllocator &allocator; //! The buffer that users can write to data_ptr_t buffer; //! The user-facing size of the buffer. diff --git a/src/duckdb/src/include/duckdb/common/file_system.hpp b/src/duckdb/src/include/duckdb/common/file_system.hpp index 4e95f5ce5..54dca7d4f 100644 --- a/src/duckdb/src/include/duckdb/common/file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/file_system.hpp @@ -10,16 +10,19 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/enums/file_compression_type.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/file_buffer.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/vector.hpp" #include "duckdb/common/enums/file_glob_options.hpp" -#include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/exception.hpp" #include "duckdb/common/error_data.hpp" +#include "duckdb/common/file_buffer.hpp" #include "duckdb/common/file_open_flags.hpp" #include "duckdb/common/open_file_info.hpp" +#include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/vector.hpp" + #include #undef CreateDirectory @@ -55,6 +58,15 @@ enum class FileType { FILE_TYPE_INVALID, }; +struct FileMetadata { + int64_t file_size = -1; + timestamp_t last_modification_time = timestamp_t::ninfinity(); + FileType file_type = FileType::FILE_TYPE_INVALID; + + // A key-value pair of the extended file metadata, which could store any attributes. 
+ unordered_map extended_file_info; +}; + struct FileHandle { public: DUCKDB_API FileHandle(FileSystem &file_system, string path, FileOpenFlags flags); @@ -87,6 +99,7 @@ struct FileHandle { DUCKDB_API bool OnDiskFile(); DUCKDB_API idx_t GetFileSize(); DUCKDB_API FileType GetType(); + DUCKDB_API FileMetadata Stats(); DUCKDB_API void TryAddLogger(FileOpener &opener); @@ -158,6 +171,8 @@ class FileSystem { DUCKDB_API virtual string GetVersionTag(FileHandle &handle); //! Returns the file type of the attached handle DUCKDB_API virtual FileType GetFileType(FileHandle &handle); + //! Returns the file stats of the attached handle. + DUCKDB_API virtual FileMetadata Stats(FileHandle &handle); //! Truncate a file to a maximum size of new_size, new_size should be smaller than or equal to the current size of //! the file DUCKDB_API virtual void Truncate(FileHandle &handle, int64_t new_size); @@ -282,6 +297,8 @@ class FileSystem { DUCKDB_API virtual void SetDisabledFileSystems(const vector &names); DUCKDB_API virtual bool SubSystemIsDisabled(const string &name); + //! Check if the filesystem that would handle this path is disabled + DUCKDB_API virtual bool IsDisabledForPath(const string &path); DUCKDB_API static bool IsDirectory(const OpenFileInfo &info); diff --git a/src/duckdb/src/include/duckdb/common/helper.hpp b/src/duckdb/src/include/duckdb/common/helper.hpp index d5fb4b465..118bada1e 100644 --- a/src/duckdb/src/include/duckdb/common/helper.hpp +++ b/src/duckdb/src/include/duckdb/common/helper.hpp @@ -136,20 +136,6 @@ shared_ptr shared_ptr_cast(shared_ptr src) { // NOLINT: mimic std styl return shared_ptr(std::static_pointer_cast(src.internal)); } -struct SharedConstructor { - template - static shared_ptr Create(ARGS &&...args) { - return make_shared_ptr(std::forward(args)...); - } -}; - -struct UniqueConstructor { - template - static unique_ptr Create(ARGS &&...args) { - return make_uniq(std::forward(args)...); - } -}; - #ifdef DUCKDB_DEBUG_MOVE template typename std::remove_reference::type&& move(T&& t) noexcept { diff --git a/src/duckdb/src/include/duckdb/common/http_util.hpp b/src/duckdb/src/include/duckdb/common/http_util.hpp index 51127179d..11fc26c48 100644 --- a/src/duckdb/src/include/duckdb/common/http_util.hpp +++ b/src/duckdb/src/include/duckdb/common/http_util.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/types.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/http_status_code.hpp" +#include "duckdb/common/types/timestamp.hpp" #include namespace duckdb { @@ -143,6 +144,11 @@ struct BaseRequest { //! 
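// --- Usage sketch (not part of the patch; InspectFile is a hypothetical caller) ---
// Possible call site for the new FileHandle::Stats() entry point, using only the
// FileMetadata fields declared above; the defaults (-1 size, -infinity modification
// time, FILE_TYPE_INVALID) act as "unknown" markers.
static void InspectFile(duckdb::FileHandle &handle) {
	duckdb::FileMetadata meta = handle.Stats();
	if (meta.file_size >= 0) {
		// size is known
	}
	if (meta.file_type != duckdb::FileType::FILE_TYPE_INVALID) {
		// type is known
	}
	// meta.extended_file_info carries any filesystem-specific key/value attributes
}
// -----------------------------------------------------------------------------------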
Whether or not to return failed requests (instead of throwing) bool try_request = false; + // Requests will optionally contain their timings + bool have_request_timing = false; + timestamp_t request_start; + timestamp_t request_end; + template TARGET &Cast() { return reinterpret_cast(*this); @@ -210,6 +216,7 @@ struct PostRequestInfo : public BaseRequest { class HTTPClient { public: virtual ~HTTPClient() = default; + virtual void Initialize(HTTPParams &http_params) = 0; virtual unique_ptr Get(GetRequestInfo &info) = 0; virtual unique_ptr Put(PutRequestInfo &info) = 0; diff --git a/src/duckdb/src/include/duckdb/common/hugeint.hpp b/src/duckdb/src/include/duckdb/common/hugeint.hpp index c9b54bd95..acdc4fb4b 100644 --- a/src/duckdb/src/include/duckdb/common/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/hugeint.hpp @@ -76,7 +76,7 @@ struct hugeint_t { // NOLINT: use numeric casing DUCKDB_API explicit operator int16_t() const; DUCKDB_API explicit operator int32_t() const; DUCKDB_API explicit operator int64_t() const; - DUCKDB_API operator uhugeint_t() const; // NOLINT: Allow implicit conversion from `hugeint_t` + DUCKDB_API explicit operator uhugeint_t() const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp index 7240d62d8..5b6f11730 100644 --- a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp +++ b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp @@ -95,6 +95,10 @@ class InsertionOrderPreservingMap { map.resize(nz); } + void clear() { // NOLINT: match stl API + map.clear(); + } + void insert(const string &key, V &&value) { // NOLINT: match stl API if (contains(key)) { return; diff --git a/src/duckdb/src/include/duckdb/common/limits.hpp b/src/duckdb/src/include/duckdb/common/limits.hpp index 0662579ef..67a98daf0 100644 --- a/src/duckdb/src/include/duckdb/common/limits.hpp +++ b/src/duckdb/src/include/duckdb/common/limits.hpp @@ -24,10 +24,12 @@ namespace duckdb { template struct NumericLimits { static constexpr T Minimum() { - return std::numeric_limits::lowest(); + return std::numeric_limits::has_infinity ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); } static constexpr T Maximum() { - return std::numeric_limits::max(); + return std::numeric_limits::has_infinity ? std::numeric_limits::infinity() + : std::numeric_limits::max(); } static constexpr bool IsSigned() { return std::is_signed::value; diff --git a/src/duckdb/src/include/duckdb/common/local_file_system.hpp b/src/duckdb/src/include/duckdb/common/local_file_system.hpp index 8b3f7aaf2..8d941475e 100644 --- a/src/duckdb/src/include/duckdb/common/local_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/local_file_system.hpp @@ -38,8 +38,12 @@ class LocalFileSystem : public FileSystem { int64_t GetFileSize(FileHandle &handle) override; //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error timestamp_t GetLastModifiedTime(FileHandle &handle) override; + //! Returns a tag that uniquely identifies the version of the file + string GetVersionTag(FileHandle &handle) override; //! Returns the file last modified time of a file handle, returns timespec with zero on all attributes on error FileType GetFileType(FileHandle &handle) override; + //! Returns the file stats of the attached handle. + FileMetadata Stats(FileHandle &handle) override; //! 
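// --- Illustrative sketch (not part of the patch; ShowFloatLimits is hypothetical) ---
// What the NumericLimits change above means for floating-point types: since double
// has an infinity, Minimum()/Maximum() now report -inf/+inf instead of
// lowest()/max(); integer types, which have no infinity, keep their previous bounds.
#include <cstdio>
#include <limits>

static void ShowFloatLimits() {
	double lowest = std::numeric_limits<double>::lowest();     // most negative finite double
	double neg_inf = -std::numeric_limits<double>::infinity(); // new Minimum() for double
	std::printf("lowest=%g  -inf=%g  (-inf < lowest)=%d\n", lowest, neg_inf, neg_inf < lowest ? 1 : 0);
}
// --------------------------------------------------------------------------------------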
Truncate a file to a maximum size of new_size, new_size should be smaller than or equal to the current size of //! the file void Truncate(FileHandle &handle, int64_t new_size) override; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp index f0a29c7af..c9ed4da21 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp @@ -79,8 +79,8 @@ class BaseFileReader : public enable_shared_from_this { virtual bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) = 0; //! Scan a chunk from the read state - virtual void Scan(ClientContext &context, GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state, DataChunk &chunk) = 0; + virtual AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) = 0; //! Finish scanning a given file DUCKDB_API virtual void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate); //! Get progress within a given file diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp index fd6380a7e..523084d6e 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp @@ -139,7 +139,7 @@ struct MultiFileLocalColumnId { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return column_id; } idx_t GetId() const { @@ -170,7 +170,7 @@ struct MultiFileLocalIndex { } public: - operator idx_t() { // NOLINT: allow implicit conversion + operator idx_t() const { // NOLINT: allow implicit conversion return index; } idx_t GetIndex() const { diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp index 1ed169568..9cebe4fc8 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp @@ -590,28 +590,77 @@ class MultiFileFunction : public TableFunction { static void MultiFileScan(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { if (!data_p.local_state) { + data_p.async_result = SourceResultType::FINISHED; return; } auto &data = data_p.local_state->Cast(); auto &gstate = data_p.global_state->Cast(); auto &bind_data = data_p.bind_data->CastNoConst(); + if (gstate.finished) { + data_p.async_result = SourceResultType::FINISHED; + return; + } + do { auto &scan_chunk = data.scan_chunk; scan_chunk.Reset(); - data.reader->Scan(context, *gstate.global_state, *data.local_state, scan_chunk); + auto res = data.reader->Scan(context, *gstate.global_state, *data.local_state, scan_chunk); + + if (res.GetResultType() == AsyncResultType::BLOCKED) { + if (scan_chunk.size() != 0) { + throw InternalException("Unexpected behaviour from Scan, no rows should be returned"); + } + switch (data_p.results_execution_mode) { + case AsyncResultsExecutionMode::TASK_EXECUTOR: + data_p.async_result = std::move(res); + return; + case AsyncResultsExecutionMode::SYNCHRONOUS: + res.ExecuteTasksSynchronously(); + if (res.GetResultType() != AsyncResultType::HAVE_MORE_OUTPUT) 
{ + throw InternalException("Unexpected behaviour from ExecuteTasksSynchronously"); + } + // scan_chunk.size() is 0, see check above, and result is HAVE_MORE_OUTPUT, we need to loop again + continue; + } + } + output.SetCardinality(scan_chunk.size()); + if (scan_chunk.size() > 0) { bind_data.multi_file_reader->FinalizeChunk(context, bind_data, *data.reader, *data.reader_data, scan_chunk, output, data.executor, gstate.multi_file_reader_state); + } + if (res.GetResultType() == AsyncResultType::HAVE_MORE_OUTPUT) { + // Loop back to the same block + if (scan_chunk.size() == 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + continue; + } + data_p.async_result = SourceResultType::HAVE_MORE_OUTPUT; return; } - scan_chunk.Reset(); + + if (res.GetResultType() != AsyncResultType::FINISHED) { + throw InternalException("Unexpected result in MultiFileScan, must be FINISHED, is %s", + EnumUtil::ToChars(res.GetResultType())); + } + if (!TryInitializeNextBatch(context, bind_data, data, gstate)) { - return; + if (scan_chunk.size() > 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + gstate.finished = true; + data_p.async_result = SourceResultType::HAVE_MORE_OUTPUT; + } else { + data_p.async_result = SourceResultType::FINISHED; + } + } else { + if (scan_chunk.size() == 0 && data_p.results_execution_mode == AsyncResultsExecutionMode::SYNCHRONOUS) { + continue; + } + data_p.async_result = SourceResultType::HAVE_MORE_OUTPUT; } + return; } while (true); } @@ -672,7 +721,8 @@ class MultiFileFunction : public TableFunction { continue; } auto &reader_data = *reader_data_ptr; - double progress_in_file; + // Initialize progress_in_file with a default value to avoid uninitialized variable usage + double progress_in_file = 0.0; if (reader_data.file_state == MultiFileFileState::OPEN) { // file is currently open - get the progress within the file progress_in_file = reader_data.reader->GetProgressInFile(context); @@ -686,9 +736,6 @@ class MultiFileFunction : public TableFunction { // file is still being read progress_in_file = reader->GetProgressInFile(context); } - } else { - // file has not been opened yet - progress in this file is zero - progress_in_file = 0; } progress_in_file = MaxValue(0.0, MinValue(100.0, progress_in_file)); total_progress += progress_in_file; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp index d9801e5d0..556a33d3b 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp @@ -166,6 +166,7 @@ struct MultiFileGlobalState : public GlobalTableFunctionState { vector scanned_types; vector column_indexes; optional_ptr filters; + atomic finished {false}; unique_ptr global_state; diff --git a/src/duckdb/src/include/duckdb/common/opener_file_system.hpp b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp index 519c2bc71..d48e3c25e 100644 --- a/src/duckdb/src/include/duckdb/common/opener_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp @@ -50,6 +50,9 @@ class OpenerFileSystem : public FileSystem { FileType GetFileType(FileHandle &handle) override { return GetFileSystem().GetFileType(handle); } + FileMetadata Stats(FileHandle &handle) override { + return GetFileSystem().Stats(handle); + } void Truncate(FileHandle &handle, int64_t new_size) override { GetFileSystem().Truncate(handle, new_size); @@ 
-147,6 +150,10 @@ class OpenerFileSystem : public FileSystem { return GetFileSystem().SubSystemIsDisabled(name); } + bool IsDisabledForPath(const string &path) override { + return GetFileSystem().IsDisabledForPath(path); + } + vector ListSubSystems() override { return GetFileSystem().ListSubSystems(); } diff --git a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp index ac55e5a69..e495e9760 100644 --- a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp @@ -1070,6 +1070,19 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, Vector &res template <> bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict); +//===--------------------------------------------------------------------===// +// GEOMETRY +//===--------------------------------------------------------------------===// +struct TryCastToGeometry { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, CastParameters ¶meters) { + throw InternalException("Unsupported type for try cast to geometry"); + } +}; + +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters); + //===--------------------------------------------------------------------===// // Pointers //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp index f1a6f6eb3..a847e217b 100644 --- a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp @@ -210,15 +210,4 @@ inline bool GreaterThan::Operation(const interval_t &left, const interval_t &rig return Interval::GreaterThan(left, right); } -//===--------------------------------------------------------------------===// -// Specialized Hugeint Comparison Operators -//===--------------------------------------------------------------------===// -template <> -inline bool Equals::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::Equals(left, right); -} -template <> -inline bool GreaterThan::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::GreaterThan(left, right); -} } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/integer_cast_operator.hpp b/src/duckdb/src/include/duckdb/common/operator/integer_cast_operator.hpp index 8bf694c42..ccaef767c 100644 --- a/src/duckdb/src/include/duckdb/common/operator/integer_cast_operator.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/integer_cast_operator.hpp @@ -103,9 +103,12 @@ struct IntegerDecimalCastOperation : IntegerCastOperation { int16_t e = exponent; // Negative Exponent if (e < 0) { - while (state.result != 0 && e++ < 0) { + while (e++ < 0) { state.decimal = state.result % 10; state.result /= 10; + if (state.result == 0 && state.decimal == 0) { + break; + } } if (state.decimal < 0) { state.decimal = -state.decimal; diff --git a/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp b/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp index 939b570ec..d7147bc0d 100644 --- a/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp +++ 
b/src/duckdb/src/include/duckdb/common/operator/numeric_cast.hpp @@ -32,7 +32,7 @@ static bool TryCastWithOverflowCheck(SRC value, DST &result) { if (NumericLimits::IsSigned()) { // signed to unsigned conversion if (NumericLimits::Digits() > NumericLimits::Digits()) { - if (value < 0 || value > (SRC)NumericLimits::Maximum()) { + if (value < 0 || value > static_cast(NumericLimits::Maximum())) { return false; } } else { @@ -40,31 +40,31 @@ static bool TryCastWithOverflowCheck(SRC value, DST &result) { return false; } } - result = (DST)value; + result = static_cast(value); return true; } else { // unsigned to signed conversion if (NumericLimits::Digits() >= NumericLimits::Digits()) { - if (value <= (SRC)NumericLimits::Maximum()) { - result = (DST)value; + if (value <= static_cast(NumericLimits::Maximum())) { + result = static_cast(value); return true; } return false; } else { - result = (DST)value; + result = static_cast(value); return true; } } } else { // same sign conversion if (NumericLimits::Digits() >= NumericLimits::Digits()) { - result = (DST)value; + result = static_cast(value); return true; } else { if (value < SRC(NumericLimits::Minimum()) || value > SRC(NumericLimits::Maximum())) { return false; } - result = (DST)value; + result = static_cast(value); return true; } } diff --git a/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp b/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp index 5fb6ba1fd..e300435a2 100644 --- a/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/pipe_file_system.hpp @@ -20,6 +20,7 @@ class PipeFileSystem : public FileSystem { int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; int64_t GetFileSize(FileHandle &handle) override; + timestamp_t GetLastModifiedTime(FileHandle &handle) override; void Reset(FileHandle &handle) override; bool OnDiskFile(FileHandle &handle) override { diff --git a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp index 1376cc1b5..6d4c79194 100644 --- a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp +++ b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp @@ -19,6 +19,14 @@ struct PrimitiveCastOperator { static TGT Operation(SRC input) { return TGT(input); } + template + static constexpr idx_t WriteSize(const TGT &input) { + return sizeof(TGT); + } + template + static void WriteToStream(const TGT &input, WriteStream &ser) { + ser.WriteData(const_data_ptr_cast(&input), sizeof(TGT)); + } }; template @@ -51,21 +59,19 @@ class PrimitiveDictionary { : capacity * sizeof(TGT))), target_stream(allocated_target.get(), allocated_target.GetSize()), dictionary(reinterpret_cast(allocated_dictionary.get())), full(false) { - // Initialize empty - for (idx_t i = 0; i < capacity; i++) { - dictionary[i].index = INVALID_INDEX; - } + Clear(); } public: //! Insert value into dictionary (if not full) + template void Insert(SRC value) { if (full) { return; } auto &entry = Lookup(value); if (entry.IsEmpty()) { - if (size + 1 > maximum_size || !AddToTarget(value)) { + if (size + 1 > maximum_size || (ADD_TO_TARGET && !AddToTarget(value))) { full = true; return; } @@ -128,7 +134,13 @@ class PrimitiveDictionary { allocated_target.Reset(); } -private: + void Clear() { + for (idx_t i = 0; i < capacity; i++) { + dictionary[i].index = INVALID_INDEX; + } + size = 0; + full = false; + } //! 
Look up a value in the dictionary using linear probing primitive_dictionary_entry_t &Lookup(const SRC &value) const { auto offset = Hash(value) & capacity_mask; @@ -138,6 +150,7 @@ class PrimitiveDictionary { return dictionary[offset]; } +private: //! Write a value to the target data (if source is not string) template ::value, int>::type = 0> bool AddToTarget(const SRC &src_value) { @@ -205,7 +218,7 @@ class PrimitiveDictionary { //! Maximum size and current size const idx_t maximum_size; - idx_t size; + uint32_t size; //! Dictionary capacity (power of two) and corresponding mask const idx_t capacity; diff --git a/src/duckdb/src/include/duckdb/common/profiler.hpp b/src/duckdb/src/include/duckdb/common/profiler.hpp index 5fb65337a..ca682720f 100644 --- a/src/duckdb/src/include/duckdb/common/profiler.hpp +++ b/src/duckdb/src/include/duckdb/common/profiler.hpp @@ -13,36 +13,60 @@ namespace duckdb { -//! The profiler can be used to measure elapsed time +//! Profiler class to measure the elapsed time. template class BaseProfiler { public: - //! Starts the timer + //! Start the timer. void Start() { finished = false; + ran = true; start = Tick(); } - //! Finishes timing + //! End the timer. void End() { end = Tick(); finished = true; } + //! Reset the timer. + void Reset() { + finished = false; + ran = false; + } - //! Returns the elapsed time in seconds. If End() has been called, returns - //! the total elapsed time. Otherwise returns how far along the timer is - //! right now. + //! Returns the elapsed time in seconds. + //! If ran is false, it returns 0. + //! If End() has been called, it returns the total elapsed time, + //! otherwise, returns how far along the timer is right now. double Elapsed() const { + if (!ran) { + return 0; + } auto measured_end = finished ? end : Tick(); return std::chrono::duration_cast>(measured_end - start).count(); } + idx_t ElapsedNanos() const { + if (!ran) { + return 0; + } + auto measured_end = finished ? end : Tick(); + return static_cast(std::chrono::duration_cast(measured_end - start).count()); + } + private: + //! Current time point. time_point Tick() const { return T::now(); } + //! Start time point. time_point start; + //! End time point. time_point end; + //! True, if end End() been called. bool finished = false; + //! True, if the timer was ran. 
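Aside (not part of the patch): the ran flag added to BaseProfiler above lets Elapsed() and ElapsedNanos() return 0 for a timer that was never started, instead of reading an uninitialized time point. A minimal standalone sketch of that behaviour, written against std::chrono directly rather than DuckDB's templated profiler:

#include <chrono>

class SketchProfiler {
public:
	void Start() {
		finished = false;
		ran = true;
		start = std::chrono::steady_clock::now();
	}
	void End() {
		end = std::chrono::steady_clock::now();
		finished = true;
	}
	// Reports 0 until Start() has been called at least once, a running value
	// between Start() and End(), and the total once End() has been called.
	double Elapsed() const {
		if (!ran) {
			return 0;
		}
		auto measured_end = finished ? end : std::chrono::steady_clock::now();
		return std::chrono::duration<double>(measured_end - start).count();
	}

private:
	std::chrono::steady_clock::time_point start, end;
	bool finished = false;
	bool ran = false;
};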
+ bool ran = false; }; using Profiler = BaseProfiler; diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp index ae6be8549..a4ad9444c 100644 --- a/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp +++ b/src/duckdb/src/include/duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp @@ -13,7 +13,6 @@ #include "duckdb/common/unicode_bar.hpp" #include "duckdb/common/progress_bar/unscented_kalman_filter.hpp" #include -#include namespace duckdb { @@ -30,21 +29,26 @@ struct TerminalProgressBarDisplayedProgressInfo { } }; -class TerminalProgressBarDisplay : public ProgressBarDisplay { -private: - UnscentedKalmanFilter ukf; - std::chrono::steady_clock::time_point start_time; - - // track the progress info that has been previously - // displayed to prevent redundant updates - struct TerminalProgressBarDisplayedProgressInfo displayed_progress_info; - - double GetElapsedDuration() { - auto now = std::chrono::steady_clock::now(); - return std::chrono::duration(now - start_time).count(); - } - void StopPeriodicUpdates(); +struct ProgressBarDisplayInfo { + idx_t width = 38; +#ifndef DUCKDB_ASCII_TREE_RENDERER + const char *progress_empty = " "; + const char *const *progress_partial = UnicodeBar::PartialBlocks(); + idx_t partial_block_count = UnicodeBar::PartialBlocksCount(); + const char *progress_block = UnicodeBar::FullBlock(); + const char *progress_start = "\xE2\x96\x95"; + const char *progress_end = "\xE2\x96\x8F"; +#else + const char *progress_empty = " "; + const char *const progress_partial[PARTIAL_BLOCK_COUNT] = {" ", " ", " ", " ", " ", " ", " ", " "}; + idx_t partial_block_count = 8; + const char *progress_block = "="; + const char *progress_start = "["; + const char *progress_end = "]"; +#endif +}; +class TerminalProgressBarDisplay : public ProgressBarDisplay { public: TerminalProgressBarDisplay() { start_time = std::chrono::steady_clock::now(); @@ -57,32 +61,33 @@ class TerminalProgressBarDisplay : public ProgressBarDisplay { public: void Update(double percentage) override; void Finish() override; + static string FormatETA(double seconds, bool elapsed = false); + static string FormatProgressBar(const ProgressBarDisplayInfo &display_info, int32_t percentage); private: - std::mutex mtx; - std::thread periodic_update_thread; - std::condition_variable cv; void PeriodicUpdate(); - static constexpr const idx_t PARTIAL_BLOCK_COUNT = UnicodeBar::PartialBlocksCount(); -#ifndef DUCKDB_ASCII_TREE_RENDERER - const char *PROGRESS_EMPTY = " "; // NOLINT - const char *const *PROGRESS_PARTIAL = UnicodeBar::PartialBlocks(); // NOLINT - const char *PROGRESS_BLOCK = UnicodeBar::FullBlock(); // NOLINT - const char *PROGRESS_START = "\xE2\x96\x95"; // NOLINT - const char *PROGRESS_END = "\xE2\x96\x8F"; // NOLINT -#else - const char *PROGRESS_EMPTY = " "; - const char *const PROGRESS_PARTIAL[PARTIAL_BLOCK_COUNT] = {" ", " ", " ", " ", " ", " ", " ", " "}; - const char *PROGRESS_BLOCK = "="; - const char *PROGRESS_START = "["; - const char *PROGRESS_END = "]"; -#endif - static constexpr const idx_t PROGRESS_BAR_WIDTH = 38; +public: + ProgressBarDisplayInfo display_info; + +protected: + virtual void PrintProgressInternal(int32_t percentage, double estimated_remaining_seconds, + bool is_finished = false); -private: static int32_t NormalizePercentage(double percentage); - void PrintProgressInternal(int32_t percentage, 
double estimated_remaining_seconds, bool is_finished = false); + double GetElapsedDuration() { + auto now = std::chrono::steady_clock::now(); + return std::chrono::duration(now - start_time).count(); + } + void StopPeriodicUpdates(); + +private: + UnscentedKalmanFilter ukf; + std::chrono::steady_clock::time_point start_time; + + // track the progress info that has been previously + // displayed to prevent redundant updates + struct TerminalProgressBarDisplayedProgressInfo displayed_progress_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/queue.hpp b/src/duckdb/src/include/duckdb/common/queue.hpp index d3e28d982..e768490cc 100644 --- a/src/duckdb/src/include/duckdb/common/queue.hpp +++ b/src/duckdb/src/include/duckdb/common/queue.hpp @@ -8,8 +8,77 @@ #pragma once +#include "duckdb/common/assert.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/memory_safety.hpp" #include namespace duckdb { -using std::queue; + +template , bool SAFE = true> +class queue : public std::queue { // NOLINT: matching name of std +public: + using original = std::queue; + using original::original; + using container_type = typename original::container_type; + using value_type = typename original::value_type; + using size_type = typename container_type::size_type; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + +public: + // Because we create the other constructor, the implicitly created constructor + // gets deleted, so we have to be explicit + queue() = default; + queue(original &&other) : original(std::move(other)) { // NOLINT: allow implicit conversion + } + template + queue(queue &&other) : original(std::move(other)) { // NOLINT + } + + inline void clear() noexcept { + original::c.clear(); + } + + reference front() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty queue!"); + } + return original::front(); + } + + const_reference front() const { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'front' called on an empty queue!"); + } + return original::front(); + } + + reference back() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty queue!"); + } + return original::back(); + } + + const_reference back() const { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'back' called on an empty queue!"); + } + return original::back(); + } + + void pop() { + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'pop' called on an empty queue!"); + } + original::pop(); + } +}; + +template > +using unsafe_queue = queue; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/radix.hpp b/src/duckdb/src/include/duckdb/common/radix.hpp index dc22256c5..bdf8b8f62 100644 --- a/src/duckdb/src/include/duckdb/common/radix.hpp +++ b/src/duckdb/src/include/duckdb/common/radix.hpp @@ -24,15 +24,6 @@ namespace duckdb { struct Radix { public: - static inline bool IsLittleEndian() { - int n = 1; - if (*char_ptr_cast(&n) == 1) { - return true; - } else { - return false; - } - } - template static inline void EncodeData(data_ptr_t dataptr, T value) { throw NotImplementedException("Cannot create data from this type"); @@ -177,7 +168,7 @@ void Radix::EncodeSigned(data_ptr_t dataptr, T value) { using UNSIGNED = typename 
MakeUnsigned::type; UNSIGNED bytes; Store(value, data_ptr_cast(&bytes)); - Store(BSwap(bytes), dataptr); + Store(BSwapIfLE(bytes), dataptr); dataptr[0] = FlipSign(dataptr[0]); } @@ -208,17 +199,17 @@ inline void Radix::EncodeData(data_ptr_t dataptr, uint8_t value) { template <> inline void Radix::EncodeData(data_ptr_t dataptr, uint16_t value) { - Store(BSwap(value), dataptr); + Store(BSwapIfLE(value), dataptr); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, uint32_t value) { - Store(BSwap(value), dataptr); + Store(BSwapIfLE(value), dataptr); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, uint64_t value) { - Store(BSwap(value), dataptr); + Store(BSwapIfLE(value), dataptr); } template <> @@ -236,13 +227,13 @@ inline void Radix::EncodeData(data_ptr_t dataptr, uhugeint_t value) { template <> inline void Radix::EncodeData(data_ptr_t dataptr, float value) { uint32_t converted_value = EncodeFloat(value); - Store(BSwap(converted_value), dataptr); + Store(BSwapIfLE(converted_value), dataptr); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, double value) { uint64_t converted_value = EncodeDouble(value); - Store(BSwap(converted_value), dataptr); + Store(BSwapIfLE(converted_value), dataptr); } template <> @@ -266,7 +257,7 @@ T Radix::DecodeSigned(const_data_ptr_t input) { auto bytes_data = data_ptr_cast(&bytes); bytes_data[0] = FlipSign(bytes_data[0]); T result; - Store(BSwap(bytes), data_ptr_cast(&result)); + Store(BSwapIfLE(bytes), data_ptr_cast(&result)); return result; } @@ -297,17 +288,17 @@ inline uint8_t Radix::DecodeData(const_data_ptr_t input) { template <> inline uint16_t Radix::DecodeData(const_data_ptr_t input) { - return BSwap(Load(input)); + return BSwapIfLE(Load(input)); } template <> inline uint32_t Radix::DecodeData(const_data_ptr_t input) { - return BSwap(Load(input)); + return BSwapIfLE(Load(input)); } template <> inline uint64_t Radix::DecodeData(const_data_ptr_t input) { - return BSwap(Load(input)); + return BSwapIfLE(Load(input)); } template <> @@ -328,12 +319,12 @@ inline uhugeint_t Radix::DecodeData(const_data_ptr_t input) { template <> inline float Radix::DecodeData(const_data_ptr_t input) { - return DecodeFloat(BSwap(Load(input))); + return DecodeFloat(BSwapIfLE(Load(input))); } template <> inline double Radix::DecodeData(const_data_ptr_t input) { - return DecodeDouble(BSwap(Load(input))); + return DecodeDouble(BSwapIfLE(Load(input))); } template <> diff --git a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp index fef1847bd..5740aae64 100644 --- a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp +++ b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp @@ -107,9 +107,9 @@ class RadixPartitionedColumnData : public PartitionedColumnData { //! 
RadixPartitionedTupleData is a PartitionedTupleData that partitions input based on the radix of a hash class RadixPartitionedTupleData : public PartitionedTupleData { public: - RadixPartitionedTupleData(BufferManager &buffer_manager, shared_ptr layout_ptr, idx_t radix_bits_p, - idx_t hash_col_idx_p); - RadixPartitionedTupleData(const RadixPartitionedTupleData &other); + RadixPartitionedTupleData(BufferManager &buffer_manager, shared_ptr layout_ptr, MemoryTag tag, + idx_t radix_bits_p, idx_t hash_col_idx_p); + RadixPartitionedTupleData(RadixPartitionedTupleData &other); ~RadixPartitionedTupleData() override; idx_t GetRadixBits() const { diff --git a/src/duckdb/src/include/duckdb/common/random_engine.hpp b/src/duckdb/src/include/duckdb/common/random_engine.hpp index 8a5a3097e..ec14b42e5 100644 --- a/src/duckdb/src/include/duckdb/common/random_engine.hpp +++ b/src/duckdb/src/include/duckdb/common/random_engine.hpp @@ -38,6 +38,8 @@ class RandomEngine { void SetSeed(uint64_t seed); + void RandomData(duckdb::data_ptr_t data, duckdb::idx_t len); + static RandomEngine &Get(ClientContext &context); mutex lock; diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp index 557e9cd5b..ee1d11afb 100644 --- a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -25,22 +25,6 @@ struct SelectionVector; class StringHeap; struct UnifiedVectorFormat; -// The NestedValidity class help to set/get the validity from inside nested vectors -class NestedValidity { - data_ptr_t list_validity_location; - data_ptr_t *struct_validity_locations; - idx_t entry_idx; - idx_t idx_in_entry; - idx_t list_validity_offset; - -public: - explicit NestedValidity(data_ptr_t validitymask_location); - NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index); - void SetInvalid(idx_t idx); - bool IsValid(idx_t idx); - void OffsetListBy(idx_t offset); -}; - struct RowOperationsState { explicit RowOperationsState(ArenaAllocator &allocator) : allocator(allocator) { } @@ -49,7 +33,7 @@ struct RowOperationsState { unique_ptr addresses; // Re-usable vector for row_aggregate.cpp }; -// RowOperations contains a set of operations that operate on data using a RowLayout +// RowOperations contains a set of operations that operate on data using a TupleDataLayout struct RowOperations { //===--------------------------------------------------------------------===// // Aggregation Operators @@ -70,66 +54,6 @@ struct RowOperations { //! finalize - unaligned addresses, updated static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, idx_t aggr_idx); - - //===--------------------------------------------------------------------===// - // Read/Write Operators - //===--------------------------------------------------------------------===// - //! Scatter group data to the rows. Initialises the ValidityMask. - static void Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count); - //! Gather a single column. - //! If heap_ptr is not null, then the data is assumed to contain swizzled pointers, - //! which will be unswizzled in memory. 
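Aside (not part of the patch): the IntegerDecimalCastOperation change in integer_cast_operator.hpp earlier in this diff no longer stops as soon as state.result reaches 0; it keeps consuming the negative exponent and only exits early once both the remaining value and the carried rounding digit are zero, so an input such as 5e-3 no longer leaves a stale rounding digit behind. A small self-contained sketch of the new loop; ApplyNegativeExponent is an illustrative name, not DuckDB API:

#include <cstdint>
#include <cstdio>

// Divide by 10 once per exponent step, remembering the digit that was shifted out
// (the digit used for rounding), and only stop early when both the value and that
// digit are zero, mirroring the new loop in the patch.
static void ApplyNegativeExponent(int64_t value, int exponent_steps, int64_t &result, int64_t &round_digit) {
	result = value;
	round_digit = 0;
	for (int e = 0; e < exponent_steps; e++) {
		round_digit = result % 10;
		result /= 10;
		if (result == 0 && round_digit == 0) {
			break;
		}
	}
}

int main() {
	int64_t result, digit;
	ApplyNegativeExponent(5, 3, result, digit); // 5e-3
	// With the old early exit (stop when result == 0) digit would have stayed 5,
	// as if the value were 0.5; here the remaining steps run and digit ends up 0.
	std::printf("result=%lld round_digit=%lld\n", (long long)result, (long long)digit);
	return 0;
}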
- static void Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size = 0, - data_ptr_t heap_ptr = nullptr); - - //===--------------------------------------------------------------------===// - // Heap Operators - //===--------------------------------------------------------------------===// - //! Compute the entry sizes of a vector with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset = 0); - //! Compute the entry sizes of vector data with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset = 0); - //! Scatter vector with variable size type to the heap. - static void HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset = 0); - //! Scatter vector data with variable size type to the heap. - static void HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset = 0); - //! Gather a single column with variable size type from the heap. - static void HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t key_locations[], - optional_ptr parent_validity); - - //===--------------------------------------------------------------------===// - // Sorting Operators - //===--------------------------------------------------------------------===// - //! Scatter vector data to the rows in radix-sortable format. - static void RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t key_locations[], bool desc, bool has_null, bool nulls_first, idx_t prefix_len, - idx_t width, idx_t offset = 0); - - //===--------------------------------------------------------------------===// - // Out-of-Core Operators - //===--------------------------------------------------------------------===// - //! Swizzles blob pointers to offset within heap row - static void SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count); - //! Swizzles the base pointer of each row to offset within heap block - static void SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset = 0); - //! Copies 'count' heap rows that are pointed to by the rows at 'row_ptr' to 'heap_ptr' and swizzles the pointers - static void CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count); - - //! Unswizzles the base offset within heap block the rows to pointers - static void UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); - //! 
Unswizzles all offsets back to pointers - static void UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp index 633e74fa1..bb7bceddf 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp @@ -15,6 +15,7 @@ #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp index f30cf5790..5248b04a6 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/encoding_util.hpp @@ -14,7 +14,6 @@ namespace duckdb { struct EncodingUtil { - // Encode unsigned integer, returns the number of bytes written template static idx_t EncodeUnsignedLEB128(data_ptr_t target, T value) { diff --git a/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp b/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp index c53bd3004..b08b34471 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/memory_stream.hpp @@ -50,6 +50,7 @@ class MemoryStream : public WriteStream, public ReadStream { //! Write data to the stream. //! Throws if the write would exceed the capacity of the stream and the backing buffer is not owned by the stream void WriteData(const_data_ptr_t buffer, idx_t write_size) override; + void GrowCapacity(idx_t write_size); //! Read data from the stream. //! 
Throws if the read would exceed the capacity of the stream diff --git a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp index 5bde0f9a1..bdb82b0c9 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -186,7 +186,6 @@ struct is_atomic> : std::true_type { // NOLINTEND struct SerializationDefaultValue { - template static inline typename std::enable_if::value, T>::type GetDefault() { using INNER = typename is_atomic::TYPE; diff --git a/src/duckdb/src/include/duckdb/common/serializer/varint.hpp b/src/duckdb/src/include/duckdb/common/serializer/varint.hpp index 8d0316a32..8cccd6f56 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/varint.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/varint.hpp @@ -35,7 +35,8 @@ uint8_t GetVarintSize(T val) { } template -void VarintEncode(T val, data_ptr_t ptr) { +idx_t VarintEncode(T val, data_ptr_t ptr) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -44,11 +45,14 @@ void VarintEncode(T val, data_ptr_t ptr) { } *ptr = byte; ptr++; + size++; } while (val != 0); + return size; } template -void VarintEncode(T val, MemoryStream &ser) { +idx_t VarintEncode(T val, MemoryStream &ser) { + idx_t size = 0; do { uint8_t byte = val & 127; val >>= 7; @@ -56,7 +60,9 @@ void VarintEncode(T val, MemoryStream &ser) { byte |= 128; } ser.WriteData(&byte, sizeof(uint8_t)); + size++; } while (val != 0); + return size; } } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/shared_ptr.hpp b/src/duckdb/src/include/duckdb/common/shared_ptr.hpp index eff60e3db..49419753e 100644 --- a/src/duckdb/src/include/duckdb/common/shared_ptr.hpp +++ b/src/duckdb/src/include/duckdb/common/shared_ptr.hpp @@ -8,45 +8,7 @@ #pragma once -#include "duckdb/common/unique_ptr.hpp" -#include "duckdb/common/likely.hpp" -#include "duckdb/common/memory_safety.hpp" - -#include -#include - -namespace duckdb { - -// This implementation is taken from the llvm-project, at this commit hash: -// https://github.com/llvm/llvm-project/blob/08bb121835be432ac52372f92845950628ce9a4a/libcxx/include/__memory/shared_ptr.h#353 -// originally named '__compatible_with' - -#if _LIBCPP_STD_VER >= 17 -template -struct __bounded_convertible_to_unbounded : std::false_type {}; - -template -struct __bounded_convertible_to_unbounded<_Up[_Np], T> : std::is_same, _Up[]> {}; - -template -struct compatible_with_t : std::_Or, __bounded_convertible_to_unbounded> {}; -#else -template -struct compatible_with_t : std::is_convertible {}; // NOLINT: invalid case style -#endif // _LIBCPP_STD_VER >= 17 - -} // namespace duckdb - +#include "duckdb/common/compatible_with_ipp.hpp" #include "duckdb/common/shared_ptr_ipp.hpp" #include "duckdb/common/weak_ptr_ipp.hpp" #include "duckdb/common/enable_shared_from_this_ipp.hpp" - -namespace duckdb { - -template -using unsafe_shared_ptr = shared_ptr; - -template -using unsafe_weak_ptr = weak_ptr; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/shared_ptr_ipp.hpp b/src/duckdb/src/include/duckdb/common/shared_ptr_ipp.hpp index d046dc141..8668bbb4c 100644 --- a/src/duckdb/src/include/duckdb/common/shared_ptr_ipp.hpp +++ b/src/duckdb/src/include/duckdb/common/shared_ptr_ipp.hpp @@ -1,3 +1,7 @@ +#pragma once + +#include "duckdb/common/compatible_with_ipp.hpp" + namespace duckdb { template @@ -79,10 +83,11 @@ class 
shared_ptr { // NOLINT: invalid case style shared_ptr(shared_ptr &&ref) noexcept // NOLINT: not marked as explicit : internal(std::move(ref.internal)) { } + // move constructor #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - shared_ptr(shared_ptr &&other) // NOLINT: not marked as explicit + shared_ptr(shared_ptr &&other) noexcept : internal(std::move(other.internal)) { } @@ -115,7 +120,7 @@ class shared_ptr { // NOLINT: invalid case style ~shared_ptr() = default; // Assign from shared_ptr copy - shared_ptr &operator=(const shared_ptr &other) noexcept { + shared_ptr &operator=(const shared_ptr &other) noexcept { if (this == &other) { return *this; } @@ -130,13 +135,13 @@ class shared_ptr { // NOLINT: invalid case style } // Assign from moved shared_ptr - shared_ptr &operator=(shared_ptr &&other) noexcept { + shared_ptr &operator=(shared_ptr &&other) noexcept { // Create a new shared_ptr using the move constructor, then swap out the ownership to *this shared_ptr(std::move(other)).swap(*this); return *this; } template ::value, int>::type = 0> - shared_ptr &operator=(shared_ptr &&other) { + shared_ptr &operator=(shared_ptr &&other) { shared_ptr(std::move(other)).swap(*this); return *this; } @@ -146,7 +151,7 @@ class shared_ptr { // NOLINT: invalid case style typename std::enable_if::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr &operator=(unique_ptr &&ref) { + shared_ptr &operator=(unique_ptr &&ref) { shared_ptr(std::move(ref)).swap(*this); return *this; } @@ -265,4 +270,7 @@ class shared_ptr { // NOLINT: invalid case style } }; +template +using unsafe_shared_ptr = shared_ptr; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp b/src/duckdb/src/include/duckdb/common/sort/comparators.hpp deleted file mode 100644 index 5f3cd3807..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp +++ /dev/null @@ -1,65 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/comparators.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -struct SortLayout; -struct SBScanState; - -using ValidityBytes = RowLayout::ValidityBytes; - -struct Comparators { -public: - //! Whether a tie between two blobs can be broken - static bool TieIsBreakable(const idx_t &col_idx, const data_ptr_t &row_ptr, const SortLayout &sort_layout); - //! Compares the tuples that a being read from in the 'left' and 'right blocks during merge sort - //! (only in case we cannot simply 'memcmp' - if there are blob columns) - static int CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort); - //! Compare two blob values - static int CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type); - -private: - //! Compares two blob values that were initially tied by their prefix - static int BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external); - //! Compare two fixed-size values - template - static int TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr); - - //! 
Compare two values at the pointers (can be recursive if nested type) - static int CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid); - //! Compares two fixed-size values at the given pointers - template - static int TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr); - //! Compares two string values at the given pointers - static int CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid); - //! Compares two struct values at the given pointers (recursive) - static int CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid); - static int CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid, - idx_t array_size); - //! Compare two list values at the pointers (can be recursive if nested type) - static int CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid); - //! Compares a list of fixed-size values - template - static int TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const ValidityBytes &left_validity, - const ValidityBytes &right_validity, const idx_t &count); - - //! Unwizzles an offset into a pointer - static void UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); - //! Swizzles a pointer into an offset - static void SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp deleted file mode 100644 index c935a713a..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ /dev/null @@ -1,710 +0,0 @@ -/* -pdqsort.h - Pattern-defeating quicksort. - -Copyright (c) 2021 Orson Peters - -This software is provided 'as-is', without any express or implied warranty. In no event will the -authors be held liable for any damages arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, including commercial -applications, and to alter it and redistribute it freely, subject to the following restrictions: - -1. The origin of this software must not be misrepresented; you must not claim that you wrote the - original software. If you use this software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - -2. Altered source versions must be plainly marked as such, and must not be misrepresented as - being the original software. - -3. This notice may not be removed or altered from any source distribution. -*/ - -#pragma once - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/unique_ptr.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb_pdqsort { - -using duckdb::data_ptr_t; -using duckdb::data_t; -using duckdb::FastMemcmp; -using duckdb::FastMemcpy; -using duckdb::idx_t; -using duckdb::make_unsafe_uniq_array_uninitialized; -using duckdb::unique_ptr; -using duckdb::unsafe_unique_array; - -// NOLINTBEGIN - -enum { - // Partitions below this size are sorted using insertion sort. - insertion_sort_threshold = 24, - - // Partitions above this size use Tukey's ninther to select the pivot. 
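For reference (not part of the patch): Tukey's ninther mentioned in the comment above is the median of the medians of three groups of three, which pdqsort uses as a cheap pivot estimate on large partitions. A tiny sketch over plain integers; Median3 and Ninther are illustrative names:

#include <algorithm>
#include <cstdio>

// Median of three values.
static int Median3(int a, int b, int c) {
	return std::max(std::min(a, b), std::min(std::max(a, b), c));
}

// Tukey's ninther: median of the medians of three triples.
static int Ninther(const int v[9]) {
	return Median3(Median3(v[0], v[1], v[2]), Median3(v[3], v[4], v[5]), Median3(v[6], v[7], v[8]));
}

int main() {
	const int v[9] = {9, 1, 8, 2, 7, 3, 6, 4, 5};
	std::printf("ninther = %d\n", Ninther(v)); // medians 8, 3, 5 -> prints 5
	return 0;
}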
- ninther_threshold = 128, - - // When we detect an already sorted partition, attempt an insertion sort that allows this - // amount of element moves before giving up. - partial_insertion_sort_limit = 8, - - // Must be multiple of 8 due to loop unrolling, and < 256 to fit in unsigned char. - block_size = 64, - - // Cacheline size, assumes power of two. - cacheline_size = 64 - -}; - -// Returns floor(log2(n)), assumes n > 0. -template -inline int log2(T n) { - int log = 0; - while (n >>= 1) { - ++log; - } - return log; -} - -struct PDQConstants { - PDQConstants(idx_t entry_size, idx_t comp_offset, idx_t comp_size, data_ptr_t end) - : entry_size(entry_size), comp_offset(comp_offset), comp_size(comp_size), - tmp_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), tmp_buf(tmp_buf_ptr.get()), - iter_swap_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - iter_swap_buf(iter_swap_buf_ptr.get()), - swap_offsets_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - swap_offsets_buf(swap_offsets_buf_ptr.get()), end(end) { - } - - const duckdb::idx_t entry_size; - const idx_t comp_offset; - const idx_t comp_size; - - unsafe_unique_array tmp_buf_ptr; - const data_ptr_t tmp_buf; - - unsafe_unique_array iter_swap_buf_ptr; - const data_ptr_t iter_swap_buf; - - unsafe_unique_array swap_offsets_buf_ptr; - const data_ptr_t swap_offsets_buf; - - const data_ptr_t end; -}; - -struct PDQIterator { - PDQIterator(data_ptr_t ptr, const idx_t &entry_size) : ptr(ptr), entry_size(entry_size) { - } - - inline PDQIterator(const PDQIterator &other) : ptr(other.ptr), entry_size(other.entry_size) { - } - - inline const data_ptr_t &operator*() const { - return ptr; - } - - inline PDQIterator &operator++() { - ptr += entry_size; - return *this; - } - - inline PDQIterator &operator--() { - ptr -= entry_size; - return *this; - } - - inline PDQIterator operator++(int) { - auto tmp = *this; - ptr += entry_size; - return tmp; - } - - inline PDQIterator operator--(int) { - auto tmp = *this; - ptr -= entry_size; - return tmp; - } - - inline PDQIterator operator+(const idx_t &i) const { - auto result = *this; - result.ptr += i * entry_size; - return result; - } - - inline PDQIterator operator-(const idx_t &i) const { - PDQIterator result = *this; - result.ptr -= i * entry_size; - return result; - } - - inline PDQIterator &operator=(const PDQIterator &other) { - D_ASSERT(entry_size == other.entry_size); - ptr = other.ptr; - return *this; - } - - inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { - D_ASSERT(duckdb::NumericCast(*lhs - *rhs) % lhs.entry_size == 0); - D_ASSERT(*lhs - *rhs >= 0); - return duckdb::NumericCast(*lhs - *rhs) / lhs.entry_size; - } - - inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs < *rhs; - } - - inline friend bool operator>(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs > *rhs; - } - - inline friend bool operator>=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs >= *rhs; - } - - inline friend bool operator<=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs <= *rhs; - } - - inline friend bool operator==(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs == *rhs; - } - - inline friend bool operator!=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs != *rhs; - } - -private: - data_ptr_t ptr; - const idx_t &entry_size; -}; - -static inline bool comp(const data_ptr_t &l, const data_ptr_t &r, const PDQConstants &constants) { - D_ASSERT(l == 
constants.tmp_buf || l == constants.swap_offsets_buf || l < constants.end); - D_ASSERT(r == constants.tmp_buf || r == constants.swap_offsets_buf || r < constants.end); - return FastMemcmp(l + constants.comp_offset, r + constants.comp_offset, constants.comp_size) < 0; -} - -static inline const data_ptr_t &GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.tmp_buf, src, constants.entry_size); - return constants.tmp_buf; -} - -static inline const data_ptr_t &SWAP_OFFSETS_GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.swap_offsets_buf, src, constants.entry_size); - return constants.swap_offsets_buf; -} - -static inline void MOVE(const data_ptr_t &dest, const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(dest == constants.tmp_buf || dest == constants.swap_offsets_buf || dest < constants.end); - D_ASSERT(src == constants.tmp_buf || src == constants.swap_offsets_buf || src < constants.end); - FastMemcpy(dest, src, constants.entry_size); -} - -static inline void iter_swap(const PDQIterator &lhs, const PDQIterator &rhs, const PDQConstants &constants) { - D_ASSERT(*lhs < constants.end); - D_ASSERT(*rhs < constants.end); - FastMemcpy(constants.iter_swap_buf, *lhs, constants.entry_size); - FastMemcpy(*lhs, *rhs, constants.entry_size); - FastMemcpy(*rhs, constants.iter_swap_buf, constants.entry_size); -} - -// Sorts [begin, end) using insertion sort with the given comparison function. -inline void insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Sorts [begin, end) using insertion sort with the given comparison function. Assumes -// *(begin - 1) is an element smaller than or equal to any element in [begin, end). -inline void unguarded_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Attempts to use insertion sort on [begin, end). Will return false if more than -// partial_insertion_sort_limit elements were moved, and abort sorting. Otherwise it will -// successfully sort and return true. 
-inline bool partial_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return true; - } - - std::size_t limit = 0; - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - limit += cur - sift; - } - - if (limit > partial_insertion_sort_limit) { - return false; - } - } - - return true; -} - -inline void sort2(const PDQIterator &a, const PDQIterator &b, const PDQConstants &constants) { - if (comp(*b, *a, constants)) { - iter_swap(a, b, constants); - } -} - -// Sorts the elements *a, *b and *c using comparison function comp. -inline void sort3(const PDQIterator &a, const PDQIterator &b, const PDQIterator &c, const PDQConstants &constants) { - sort2(a, b, constants); - sort2(b, c, constants); - sort2(a, b, constants); -} - -template -inline T *align_cacheline(T *p) { -#if defined(UINTPTR_MAX) && __cplusplus >= 201103L - std::uintptr_t ip = reinterpret_cast(p); -#else - std::size_t ip = reinterpret_cast(p); -#endif - ip = (ip + cacheline_size - 1) & -duckdb::UnsafeNumericCast(cacheline_size); - return reinterpret_cast(ip); -} - -inline void swap_offsets(const PDQIterator &first, const PDQIterator &last, unsigned char *offsets_l, - unsigned char *offsets_r, size_t num, bool use_swaps, const PDQConstants &constants) { - if (use_swaps) { - // This case is needed for the descending distribution, where we need - // to have proper swapping for pdqsort to remain O(n). - for (size_t i = 0; i < num; ++i) { - iter_swap(first + offsets_l[i], last - offsets_r[i], constants); - } - } else if (num > 0) { - PDQIterator l = first + offsets_l[0]; - PDQIterator r = last - offsets_r[0]; - const auto &tmp = SWAP_OFFSETS_GET_TMP(*l, constants); - MOVE(*l, *r, constants); - for (size_t i = 1; i < num; ++i) { - l = first + offsets_l[i]; - MOVE(*r, *l, constants); - r = last - offsets_r[i]; - MOVE(*l, *r, constants); - } - MOVE(*r, tmp, constants); - } -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. Uses branchless partitioning. -inline std::pair partition_right_branchless(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. 
- if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - if (!already_partitioned) { - iter_swap(first, last, constants); - ++first; - - // The following branchless partitioning is derived from "BlockQuicksort: How Branch - // Mispredictions don’t affect Quicksort" by Stefan Edelkamp and Armin Weiss, but - // heavily micro-optimized. - unsigned char offsets_l_storage[block_size + cacheline_size]; - unsigned char offsets_r_storage[block_size + cacheline_size]; - unsigned char *offsets_l = align_cacheline(offsets_l_storage); - unsigned char *offsets_r = align_cacheline(offsets_r_storage); - - PDQIterator offsets_l_base = first; - PDQIterator offsets_r_base = last; - size_t num_l, num_r, start_l, start_r; - num_l = num_r = start_l = start_r = 0; - - while (first < last) { - // Fill up offset blocks with elements that are on the wrong side. - // First we determine how much elements are considered for each offset block. - size_t num_unknown = last - first; - size_t left_split = num_l == 0 ? (num_r == 0 ? num_unknown / 2 : num_unknown) : 0; - size_t right_split = num_r == 0 ? (num_unknown - left_split) : 0; - - // Fill the offset blocks. - if (left_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } else { - for (unsigned char i = 0; i < left_split;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } - - if (right_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } else { - for (unsigned char i = 0; i < right_split;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } - - // Swap elements and update block sizes and first/last boundaries. 
- size_t num = std::min(num_l, num_r); - swap_offsets(offsets_l_base, offsets_r_base, offsets_l + start_l, offsets_r + start_r, num, num_l == num_r, - constants); - num_l -= num; - num_r -= num; - start_l += num; - start_r += num; - - if (num_l == 0) { - start_l = 0; - offsets_l_base = first; - } - - if (num_r == 0) { - start_r = 0; - offsets_r_base = last; - } - } - - // We have now fully identified [first, last)'s proper position. Swap the last elements. - if (num_l) { - offsets_l += start_l; - while (num_l--) { - iter_swap(offsets_l_base + offsets_l[num_l], --last, constants); - } - first = last; - } - if (num_r) { - offsets_r += start_r; - while (num_r--) { - iter_swap(offsets_r_base - offsets_r[num_r], first, constants), ++first; - } - last = first; - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. -inline std::pair partition_right(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - - // Keep swapping pairs of elements that are on the wrong side of the pivot. Previously - // swapped pairs guard the searches, which is why the first iteration is special-cased - // above. - while (first < last) { - iter_swap(first, last, constants); - while (comp(*++first, pivot, constants)) { - } - while (!comp(*--last, pivot, constants)) { - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Similar function to the one above, except elements equal to the pivot are put to the left of -// the pivot and it doesn't check or return if the passed sequence already was partitioned. -// Since this is rarely used (the many equal case), and in that case pdqsort already has O(n) -// performance, no block quicksort is applied here for simplicity. 
-inline PDQIterator partition_left(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - while (comp(pivot, *--last, constants)) { - } - - if (last + 1 == end) { - while (first < last && !comp(pivot, *++first, constants)) { - } - } else { - while (!comp(pivot, *++first, constants)) { - } - } - - while (first < last) { - iter_swap(first, last, constants); - while (comp(pivot, *--last, constants)) { - } - while (!comp(pivot, *++first, constants)) { - } - } - - PDQIterator pivot_pos = last; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return pivot_pos; -} - -template -inline void pdqsort_loop(PDQIterator begin, const PDQIterator &end, const PDQConstants &constants, int bad_allowed, - bool leftmost = true) { - // Use a while loop for tail recursion elimination. - while (true) { - idx_t size = end - begin; - - // Insertion sort is faster for small arrays. - if (size < insertion_sort_threshold) { - if (leftmost) { - insertion_sort(begin, end, constants); - } else { - unguarded_insertion_sort(begin, end, constants); - } - return; - } - - // Choose pivot as median of 3 or pseudomedian of 9. - idx_t s2 = size / 2; - if (size > ninther_threshold) { - sort3(begin, begin + s2, end - 1, constants); - sort3(begin + 1, begin + (s2 - 1), end - 2, constants); - sort3(begin + 2, begin + (s2 + 1), end - 3, constants); - sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), constants); - iter_swap(begin, begin + s2, constants); - } else { - sort3(begin + s2, begin, end - 1, constants); - } - - // If *(begin - 1) is the end of the right partition of a previous partition operation - // there is no element in [begin, end) that is smaller than *(begin - 1). Then if our - // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in - // the left partition, greater elements in the right partition. We do not have to - // recurse on the left partition, since it's sorted (all equal). - if (!leftmost && !comp(*(begin - 1), *begin, constants)) { - begin = partition_left(begin, end, constants) + 1; - continue; - } - - // Partition and get results. - std::pair part_result = - Branchless ? partition_right_branchless(begin, end, constants) : partition_right(begin, end, constants); - PDQIterator pivot_pos = part_result.first; - bool already_partitioned = part_result.second; - - // Check for a highly unbalanced partition. - idx_t l_size = pivot_pos - begin; - idx_t r_size = end - (pivot_pos + 1); - bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; - - // If we got a highly unbalanced partition we shuffle elements to break many patterns. - if (highly_unbalanced) { - // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). 
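Note (not part of the patch): the fallback this comment describes is the usual introsort-style guarantee: after too many unbalanced partitions, switch to heap sort, whose worst case is O(n log n) regardless of input distribution. In this vendored copy the fallback is commented out just below. A minimal sketch of such a fallback on a plain vector:

#include <algorithm>
#include <vector>

// Heap-sort fallback: O(n log n) worst case, independent of the input ordering.
static void HeapSortFallback(std::vector<int> &values) {
	std::make_heap(values.begin(), values.end());
	std::sort_heap(values.begin(), values.end());
}

int main() {
	std::vector<int> v = {5, 3, 8, 1, 2};
	HeapSortFallback(v); // v becomes {1, 2, 3, 5, 8}
	return 0;
}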
- // if (--bad_allowed == 0) { - // std::make_heap(begin, end, comp); - // std::sort_heap(begin, end, comp); - // return; - // } - - if (l_size >= insertion_sort_threshold) { - iter_swap(begin, begin + l_size / 4, constants); - iter_swap(pivot_pos - 1, pivot_pos - l_size / 4, constants); - - if (l_size > ninther_threshold) { - iter_swap(begin + 1, begin + (l_size / 4 + 1), constants); - iter_swap(begin + 2, begin + (l_size / 4 + 2), constants); - iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1), constants); - iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2), constants); - } - } - - if (r_size >= insertion_sort_threshold) { - iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4), constants); - iter_swap(end - 1, end - r_size / 4, constants); - - if (r_size > ninther_threshold) { - iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4), constants); - iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4), constants); - iter_swap(end - 2, end - (1 + r_size / 4), constants); - iter_swap(end - 3, end - (2 + r_size / 4), constants); - } - } - } else { - // If we were decently balanced and we tried to sort an already partitioned - // sequence try to use insertion sort. - if (already_partitioned && partial_insertion_sort(begin, pivot_pos, constants) && - partial_insertion_sort(pivot_pos + 1, end, constants)) { - return; - } - } - - // Sort the left partition first using recursion and do tail recursion elimination for - // the right-hand partition. - pdqsort_loop(begin, pivot_pos, constants, bad_allowed, leftmost); - begin = pivot_pos + 1; - leftmost = false; - } -} - -inline void pdqsort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} - -inline void pdqsort_branchless(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} -// NOLINTEND - -} // namespace duckdb_pdqsort diff --git a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp b/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp deleted file mode 100644 index 8170875e8..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp +++ /dev/null @@ -1,245 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/partition_state.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" - -namespace duckdb { - -class PartitionGlobalHashGroup { -public: - using GlobalSortStatePtr = unique_ptr; - using Orders = vector; - using Types = vector; - using OrderMasks = unordered_map; - - PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, const Orders &orders, - const Types &payload_types, bool external); - - inline int ComparePartitions(const SBIterator &left, const SBIterator &right) { - int part_cmp = 0; - if (partition_layout.all_constant) { - part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size); - } else { - part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, - partition_layout, left.external); - } - return part_cmp; - } - - void ComputeMasks(ValidityMask &partition_mask, OrderMasks 
&order_masks); - - GlobalSortStatePtr global_sort; - atomic count; - - // Mask computation - SortLayout partition_layout; -}; - -class PartitionGlobalSinkState { -public: - using HashGroupPtr = unique_ptr; - using Orders = vector; - using Types = vector; - - using GroupingPartition = unique_ptr; - using GroupingAppend = unique_ptr; - - static void GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, const Orders &order_bys, - const vector> &partitions_stats); - - PartitionGlobalSinkState(ClientContext &context, const vector> &partition_bys, - const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); - virtual ~PartitionGlobalSinkState() = default; - - bool HasMergeTasks() const; - - unique_ptr CreatePartition(idx_t new_bits) const; - void SyncPartitioning(const PartitionGlobalSinkState &other); - - void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - - virtual void OnBeginMerge() {}; - virtual void OnSortedPartition(const idx_t hash_bin_p) {}; - - ClientContext &context; - BufferManager &buffer_manager; - Allocator &allocator; - mutex lock; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr grouping_data; - //! Payload plus hash column - shared_ptr grouping_types_ptr; - //! The number of radix bits if this partition is being synced with another - idx_t fixed_bits; - - // OVER(...) (sorting) - Orders partitions; - Orders orders; - const Types payload_types; - vector hash_groups; - bool external; - // Reverse lookup from hash bins to non-empty hash groups - vector bin_groups; - - // OVER() (no sorting) - unique_ptr rows; - unique_ptr strings; - - // Threading - idx_t memory_per_thread; - idx_t max_bits; - atomic count; - -private: - void ResizeGroupingData(idx_t cardinality); - void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); -}; - -class PartitionLocalSinkState { -public: - using LocalSortStatePtr = unique_ptr; - - PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p); - - // Global state - PartitionGlobalSinkState &gstate; - Allocator &allocator; - - // Shared expression evaluation - ExpressionExecutor executor; - DataChunk group_chunk; - DataChunk payload_chunk; - size_t sort_cols; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr local_partition; - unique_ptr local_append; - - // OVER(ORDER BY...) (only sorting) - LocalSortStatePtr local_sort; - - // OVER() (no sorting) - RowLayout payload_layout; - unique_ptr rows; - unique_ptr strings; - - //! Compute the hash values - void Hash(DataChunk &input_chunk, Vector &hash_vector); - //! Sink an input chunk - void Sink(DataChunk &input_chunk); - //! Merge the state into the global state. - void Combine(); -}; - -enum class PartitionSortStage : uint8_t { INIT, SCAN, PREPARE, MERGE, SORTED, FINISHED }; - -class PartitionLocalMergeState; - -class PartitionGlobalMergeState { -public: - using GroupDataPtr = unique_ptr; - - // OVER(PARTITION BY...) - PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data, hash_t hash_bin); - - // OVER(ORDER BY...) 
- explicit PartitionGlobalMergeState(PartitionGlobalSinkState &sink); - - bool IsFinished() const { - return stage == PartitionSortStage::FINISHED; - } - - bool AssignTask(PartitionLocalMergeState &local_state); - bool TryPrepareNextStage(); - void CompleteTask(); - - PartitionGlobalSinkState &sink; - GroupDataPtr group_data; - PartitionGlobalHashGroup *hash_group; - const idx_t group_idx; - vector column_ids; - TupleDataParallelScanState chunk_state; - GlobalSortState *global_sort; - const idx_t memory_per_thread; - const idx_t num_threads; - -private: - mutable mutex lock; - atomic stage; - idx_t total_tasks; - idx_t tasks_assigned; - idx_t tasks_completed; -}; - -class PartitionLocalMergeState { -public: - explicit PartitionLocalMergeState(PartitionGlobalSinkState &gstate); - - bool TaskFinished() { - return finished; - } - - void Prepare(); - void Scan(); - void Merge(); - void Sorted(); - - void ExecuteTask(); - - PartitionGlobalMergeState *merge_state; - PartitionSortStage stage; - atomic finished; - - // Sorting buffers - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk payload_chunk; -}; - -class PartitionGlobalMergeStates { -public: - struct Callback { - virtual ~Callback() = default; - - virtual bool HasError() const { - return false; - } - }; - - using PartitionGlobalMergeStatePtr = unique_ptr; - - explicit PartitionGlobalMergeStates(PartitionGlobalSinkState &sink); - - bool ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback); - - vector states; -}; - -class PartitionMergeEvent : public BasePipelineEvent { -public: - PartitionMergeEvent(PartitionGlobalSinkState &gstate_p, Pipeline &pipeline_p, const PhysicalOperator &op_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), merge_states(gstate_p), op(op_p) { - } - - PartitionGlobalSinkState &gstate; - PartitionGlobalMergeStates merge_states; - const PhysicalOperator &op; - -public: - void Schedule() override; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sort.hpp b/src/duckdb/src/include/duckdb/common/sort/sort.hpp deleted file mode 100644 index 188ea2127..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sort.hpp +++ /dev/null @@ -1,290 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sort.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class RowLayout; -struct LocalSortState; - -struct SortConstants { - static constexpr idx_t VALUES_PER_RADIX = 256; - static constexpr idx_t MSD_RADIX_LOCATIONS = VALUES_PER_RADIX + 1; - static constexpr idx_t INSERTION_SORT_THRESHOLD = 24; - static constexpr idx_t MSD_RADIX_SORT_SIZE_THRESHOLD = 4; -}; - -struct SortLayout { -public: - SortLayout() { - } - explicit SortLayout(const vector &orders); - SortLayout GetPrefixComparisonLayout(idx_t num_prefix_cols) const; - -public: - idx_t column_count; - vector order_types; - vector order_by_null_types; - vector logical_types; - - bool all_constant; - vector constant_size; - vector column_sizes; - vector prefix_lengths; - vector stats; - vector has_null; - - idx_t comparison_size; - idx_t entry_size; - - RowLayout blob_layout; - unordered_map sorting_to_blob_col; -}; - -struct GlobalSortState { -public: - GlobalSortState(ClientContext &context, const 
vector &orders, RowLayout &payload_layout); - - //! Add local state sorted data to this global state - void AddLocalState(LocalSortState &local_sort_state); - //! Prepares the GlobalSortState for the merge sort phase (after completing radix sort phase) - void PrepareMergePhase(); - //! Initializes the global sort state for another round of merging - void InitializeMergeRound(); - //! Completes the cascaded merge sort round. - //! Pass true if you wish to use the radix data for further comparisons. - void CompleteMergeRound(bool keep_radix_data = false); - //! Print the sorted data to the console. - void Print(); - -public: - //! The client context - ClientContext &context; - //! The lock for updating the order global state - mutex lock; - //! The buffer manager - BufferManager &buffer_manager; - - //! Sorting and payload layouts - const SortLayout sort_layout; - const RowLayout payload_layout; - - //! Sorted data - vector> sorted_blocks; - vector>> sorted_blocks_temp; - unique_ptr odd_one_out; - - //! Pinned heap data (if sorting in memory) - vector> heap_blocks; - vector pinned_blocks; - - //! Capacity (number of rows) used to initialize blocks - idx_t block_capacity; - //! Whether we are doing an external sort - bool external; - - //! Progress in merge path stage - idx_t pair_idx; - idx_t num_pairs; - idx_t l_start; - idx_t r_start; -}; - -struct LocalSortState { -public: - LocalSortState(); - - //! Initialize the layouts and RowDataCollections - void Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p); - //! Sink one DataChunk into the local sort state - void SinkChunk(DataChunk &sort, DataChunk &payload); - //! Size of accumulated data in bytes - idx_t SizeInBytes() const; - //! Sort the data accumulated so far - void Sort(GlobalSortState &global_sort_state, bool reorder_heap); - //! Concatenate the blocks held by a RowDataCollection into a single block - static unique_ptr ConcatenateBlocks(RowDataCollection &row_data); - -private: - //! Sorts the data in the newly created SortedBlock - void SortInMemory(); - //! Re-order the local state after sorting - void ReOrder(GlobalSortState &gstate, bool reorder_heap); - //! Re-order a SortedData object after sorting - void ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap); - -public: - //! Whether this local state has been initialized - bool initialized; - //! The buffer manager - BufferManager *buffer_manager; - //! The sorting and payload layouts - const SortLayout *sort_layout; - const RowLayout *payload_layout; - //! Radix/memcmp sortable data - unique_ptr radix_sorting_data; - //! Variable sized sorting data and accompanying heap - unique_ptr blob_sorting_data; - unique_ptr blob_sorting_heap; - //! Payload data and accompanying heap - unique_ptr payload_data; - unique_ptr payload_heap; - //! Sorted data - vector> sorted_blocks; - -private: - //! Selection vector and addresses for scattering the data to rows - const SelectionVector &sel_ptr = *FlatVector::IncrementalSelectionVector(); - Vector addresses = Vector(LogicalType::POINTER); -}; - -struct MergeSorter { -public: - MergeSorter(GlobalSortState &state, BufferManager &buffer_manager); - - //! Finds and merges partitions until the current cascaded merge round is finished - void PerformInMergeRound(); - -private: - //! The global sorting state - GlobalSortState &state; - //! The sorting and payload layouts - BufferManager &buffer_manager; - const SortLayout &sort_layout; - - //! 
The left and right reader - unique_ptr left; - unique_ptr right; - - //! Input and output blocks - unique_ptr left_input; - unique_ptr right_input; - SortedBlock *result; - -private: - //! Computes the left and right block that will be merged next (Merge Path partition) - void GetNextPartition(); - //! Finds the boundary of the next partition using binary search - void GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx); - //! Compare values within SortedBlocks using a global index - int CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx); - - //! Finds the next partition and merges it - void MergePartition(); - - //! Computes how the next 'count' tuples should be merged by setting the 'left_smaller' array - void ComputeMerge(const idx_t &count, bool left_smaller[]); - - //! Merges the radix sorting blocks according to the 'left_smaller' array - void MergeRadix(const idx_t &count, const bool left_smaller[]); - //! Merges SortedData according to the 'left_smaller' array - void MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices); - //! Merges constant size rows according to the 'left_smaller' array - void MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, idx_t &r_entry_idx, - const idx_t &r_count, RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, - const bool left_smaller[], idx_t &copied, const idx_t &count); - //! Flushes constant size rows into the result - void FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count); - //! 
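
[Editor's note: GetIntersection above follows the Merge Path scheme: for a given diagonal of the merge, binary-search how many tuples must come from the left run so that independent slices of the two runs can be merged in parallel. A self-contained sketch on plain int vectors, assuming a hypothetical helper name; it is not the DuckDB implementation.]

#include <algorithm>
#include <cstddef>
#include <vector>

// Returns l_idx such that taking l_idx tuples from `left` and (diagonal - l_idx)
// tuples from `right` yields exactly the first `diagonal` tuples of the merge.
static size_t MergePathIntersection(const std::vector<int> &left, const std::vector<int> &right, size_t diagonal) {
	size_t lo = diagonal > right.size() ? diagonal - right.size() : 0;
	size_t hi = std::min(diagonal, left.size());
	while (lo < hi) {
		const size_t l_idx = (lo + hi) / 2;    // tuples taken from the left run
		const size_t r_idx = diagonal - l_idx; // tuples taken from the right run
		if (right[r_idx - 1] > left[l_idx]) {
			lo = l_idx + 1; // too few taken from the left run
		} else {
			hi = l_idx; // feasible partition, try taking fewer from the left run
		}
	}
	return lo;
}
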
Flushes blob rows and accompanying heap - void FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, BufferHandle &target_heap_handle, - data_ptr_t &target_heap_ptr, idx_t &copied, const idx_t &count); -}; - -struct SBIterator { - static int ComparisonValue(ExpressionType comparison); - - SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p = 0); - - inline idx_t GetIndex() const { - return entry_idx; - } - - inline void SetIndex(idx_t entry_idx_p) { - const auto new_block_idx = entry_idx_p / block_capacity; - if (new_block_idx != scan.block_idx) { - scan.SetIndices(new_block_idx, 0); - if (new_block_idx < block_count) { - scan.PinRadix(scan.block_idx); - block_ptr = scan.RadixPtr(); - if (!all_constant) { - scan.PinData(*scan.sb->blob_sorting_data); - } - } - } - - scan.entry_idx = entry_idx_p % block_capacity; - entry_ptr = block_ptr + scan.entry_idx * entry_size; - entry_idx = entry_idx_p; - } - - inline SBIterator &operator++() { - if (++scan.entry_idx < block_capacity) { - entry_ptr += entry_size; - ++entry_idx; - } else { - SetIndex(entry_idx + 1); - } - - return *this; - } - - inline SBIterator &operator--() { - if (scan.entry_idx) { - --scan.entry_idx; - --entry_idx; - entry_ptr -= entry_size; - } else { - SetIndex(entry_idx - 1); - } - - return *this; - } - - inline bool Compare(const SBIterator &other, const SortLayout &prefix) const { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(entry_ptr, other.entry_ptr, prefix.comparison_size); - } else { - comp_res = Comparators::CompareTuple(scan, other.scan, entry_ptr, other.entry_ptr, prefix, external); - } - - return comp_res <= cmp; - } - - inline bool Compare(const SBIterator &other) const { - return Compare(other, sort_layout); - } - - // Fixed comparison parameters - const SortLayout &sort_layout; - const idx_t block_count; - const idx_t block_capacity; - const size_t entry_size; - const bool all_constant; - const bool external; - const int cmp; - - // Iteration state - SBScanState scan; - idx_t entry_idx; - data_ptr_t block_ptr; - data_ptr_t entry_ptr; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp b/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp deleted file mode 100644 index b6941bda2..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp +++ /dev/null @@ -1,165 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sorted_block.hpp -// -// -//===----------------------------------------------------------------------===// -#pragma once - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" - -namespace duckdb { - -class BufferManager; -struct RowDataBlock; -struct SortLayout; -struct GlobalSortState; - -enum class SortedDataType { BLOB, PAYLOAD }; - -//! Object that holds sorted rows, and an accompanying heap if there are blobs -struct SortedData { -public: - SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, GlobalSortState &state); - //! Number of rows that this object holds - idx_t Count(); - //! 
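
[Editor's note: SBIterator::SetIndex above maps a global tuple index onto a (block, entry) pair because sorted rows live in fixed-capacity row blocks, and it only re-pins when the block index changes. A trimmed-down sketch of that arithmetic, with an illustrative struct rather than the real iterator.]

#include <cstdint>

struct BlockPosition {
	uint64_t block_idx;
	uint64_t entry_idx;
};

// Decompose a global row index given a fixed per-block capacity.
static BlockPosition GlobalToBlockPosition(uint64_t global_idx, uint64_t block_capacity) {
	return BlockPosition {global_idx / block_capacity, global_idx % block_capacity};
}
// e.g. with block_capacity = 2048, global index 5000 maps to block 2, entry 904.
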
Initialize new block to write to - void CreateBlock(); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index); - //! Unswizzles all - void Unswizzle(); - -public: - const SortedDataType type; - //! Layout of this data - const RowLayout layout; - //! Data and heap blocks - vector> data_blocks; - vector> heap_blocks; - //! Whether the pointers in this sorted data are swizzled - bool swizzled; - -private: - //! The buffer manager - BufferManager &buffer_manager; - //! The global state - GlobalSortState &state; -}; - -//! Block that holds sorted rows: radix, blob and payload data -struct SortedBlock { -public: - SortedBlock(BufferManager &buffer_manager, GlobalSortState &gstate); - //! Number of rows that this object holds - idx_t Count() const; - //! Initialize this block to write data to - void InitializeWrite(); - //! Init new block to write to - void CreateBlock(); - //! Fill this sorted block by appending the blocks held by a vector of sorted blocks - void AppendSortedBlocks(vector> &sorted_blocks); - //! Locate the block and entry index of a row in this block, - //! given an index between 0 and the total number of rows in this block - void GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx); - - //! Size (in bytes) of the heap of this block - idx_t HeapSize() const; - //! Total size (in bytes) of this block - idx_t SizeInBytes() const; - -public: - //! Radix/memcmp sortable data - vector> radix_sorting_data; - //! Variable sized sorting data - unique_ptr blob_sorting_data; - //! Payload data - unique_ptr payload_data; - -private: - //! Buffer manager, global state, and sorting layout constants - BufferManager &buffer_manager; - GlobalSortState &state; - const SortLayout &sort_layout; - const RowLayout &payload_layout; -}; - -//! State used to scan a SortedBlock e.g. during merge sort -struct SBScanState { -public: - SBScanState(BufferManager &buffer_manager, GlobalSortState &state); - - void PinRadix(idx_t block_idx_to); - void PinData(SortedData &sd); - - data_ptr_t RadixPtr() const; - data_ptr_t DataPtr(SortedData &sd) const; - data_ptr_t HeapPtr(SortedData &sd) const; - data_ptr_t BaseHeapPtr(SortedData &sd) const; - - idx_t Remaining() const; - - void SetIndices(idx_t block_idx_to, idx_t entry_idx_to); - -public: - BufferManager &buffer_manager; - const SortLayout &sort_layout; - GlobalSortState &state; - - SortedBlock *sb; - - idx_t block_idx; - idx_t entry_idx; - - BufferHandle radix_handle; - - BufferHandle blob_sorting_data_handle; - BufferHandle blob_sorting_heap_handle; - - BufferHandle payload_data_handle; - BufferHandle payload_heap_handle; -}; - -//! Used to scan the data into DataChunks after sorting -struct PayloadScanner { -public: - PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush = true); - explicit PayloadScanner(GlobalSortState &global_sort_state, bool flush = true); - - //! Scan a single block - PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush = false); - - //! The type layout of the payload - inline const vector &GetPayloadTypes() const { - return scanner->GetTypes(); - } - - //! The number of rows scanned so far - inline idx_t Scanned() const { - return scanner->Scanned(); - } - - //! 
The number of remaining rows - inline idx_t Remaining() const { - return scanner->Remaining(); - } - - //! Scans the next data chunk from the sorted data - void Scan(DataChunk &chunk); - -private: - //! The sorted data being scanned - unique_ptr rows; - unique_ptr heap; - //! The actual scanner - unique_ptr scanner; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/full_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/full_sort.hpp new file mode 100644 index 000000000..ac79c0465 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sorting/full_sort.hpp @@ -0,0 +1,69 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sorting/full_sort.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sorting/sort_strategy.hpp" + +namespace duckdb { + +class FullSort : public SortStrategy { +public: + using Orders = vector; + + FullSort(ClientContext &client, const vector &order_bys, const Types &payload_types, + bool require_payload = false); + +public: + //===--------------------------------------------------------------------===// + // Sink Interface + //===--------------------------------------------------------------------===// + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &client) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const override; + ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, + const ProgressData source_progress) const override; + +public: + //===--------------------------------------------------------------------===// + // Source Interface + //===--------------------------------------------------------------------===// + unique_ptr GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const override; + +public: + //===--------------------------------------------------------------------===// + // Non-Standard Interface + //===--------------------------------------------------------------------===// + SourceResultType MaterializeColumnData(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + HashGroupPtr GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const override; + + SourceResultType MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + SortedRunPtr GetSortedRun(ClientContext &client, idx_t hash_bin, OperatorSourceInput &source) const override; + + const ChunkRows &GetHashGroups(GlobalSourceState &global_state) const override; + +public: + // OVER(...) (sorting) + Orders orders; + //! Are we creating a dummy payload column? + bool force_payload = false; + // Key columns that must be computed + vector> sort_exprs; + //! 
Common sort description + unique_ptr sort; + +private: + SourceResultType MaterializeSortedData(ExecutionContext &context, bool build_runs, + OperatorSourceInput &source) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp index 374133692..f96168fca 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp @@ -8,15 +8,16 @@ #pragma once -#include "duckdb/common/sorting/sort.hpp" +#include "duckdb/common/sorting/sort_strategy.hpp" namespace duckdb { -class HashedSort { +class HashedSort : public SortStrategy { public: using Orders = vector; using Types = vector; using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; static void GenerateOrderings(Orders &partitions, Orders &orders, const vector> &partition_bys, const Orders &order_bys, @@ -24,49 +25,56 @@ class HashedSort { HashedSort(ClientContext &context, const vector> &partition_bys, const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); + const vector> &partitions_stats, idx_t estimated_cardinality, + bool require_payload = false); public: //===--------------------------------------------------------------------===// // Sink Interface //===--------------------------------------------------------------------===// - unique_ptr GetLocalSinkState(ExecutionContext &context) const; - unique_ptr GetGlobalSinkState(ClientContext &client) const; - SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const; - SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const; - SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &client) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const override; ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, - const ProgressData source_progress) const; + const ProgressData source_progress) const override; + void Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const override; public: //===--------------------------------------------------------------------===// // Source Interface //===--------------------------------------------------------------------===// - unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const; - unique_ptr GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const; + unique_ptr GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const override; public: //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// - SinkFinalizeType MaterializeHashGroups(Pipeline &pipeline, Event &event, const PhysicalOperator &op, - OperatorSinkFinalizeInput &finalize) const; - vector &GetHashGroups(GlobalSourceState &global_state) const; + void SortColumnData(ExecutionContext &context, hash_t hash_bin, 
OperatorSinkFinalizeInput &finalize) override; + + SourceResultType MaterializeColumnData(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + HashGroupPtr GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const override; + + SourceResultType MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + SortedRunPtr GetSortedRun(ClientContext &client, idx_t hash_bin, OperatorSourceInput &source) const override; + + const ChunkRows &GetHashGroups(GlobalSourceState &global_state) const override; public: - ClientContext &client; //! The host's estimated row count const idx_t estimated_cardinality; - // OVER(...) (sorting) + //! The PARTITION BY sorting Orders partitions; + //! The ORDER BY sorting Orders orders; - idx_t sort_col_count; - Types payload_types; - // Input columns in the sorted output - vector scan_ids; - // Key columns in the sorted output - vector sort_ids; + //! The partition columns + vector partition_ids; + //! Are we creating a dummy payload column? + bool force_payload = false; // Key columns that must be computed vector> sort_exprs; //! Common sort description diff --git a/src/duckdb/src/include/duckdb/common/sorting/natural_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/natural_sort.hpp new file mode 100644 index 000000000..e7d9c513c --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sorting/natural_sort.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sorting/natural_sort.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sorting/sort_strategy.hpp" + +namespace duckdb { + +class NaturalSort : public SortStrategy { +public: + using Types = vector; + using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; + + explicit NaturalSort(const Types &payload_types); + +public: + //===--------------------------------------------------------------------===// + // Sink Interface + //===--------------------------------------------------------------------===// + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + unique_ptr GetGlobalSinkState(ClientContext &client) const override; + SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; + SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; + SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const override; + ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, + const ProgressData source_progress) const override; + +public: + //===--------------------------------------------------------------------===// + // Source Interface + //===--------------------------------------------------------------------===// + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + unique_ptr GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const override; + +public: + //===--------------------------------------------------------------------===// + // Non-Standard Interface + //===--------------------------------------------------------------------===// + SourceResultType MaterializeColumnData(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + HashGroupPtr 
GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const override; + + SourceResultType MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const override; + SortedRunPtr GetSortedRun(ClientContext &client, idx_t hash_bin, OperatorSourceInput &source) const override; + + const ChunkRows &GetHashGroups(GlobalSourceState &global_state) const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp index 597b8261b..de1e33f3b 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp @@ -8,25 +8,44 @@ #pragma once -#include "duckdb/common/sorting/sorted_run.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/execution/physical_operator_states.hpp" +#include "duckdb/execution/progress_data.hpp" #include "duckdb/common/sorting/sort_projection_column.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" namespace duckdb { class SortLocalSinkState; class SortGlobalSinkState; + class SortLocalSourceState; class SortGlobalSourceState; +class SortedRun; +class SortedRunScanState; + +class SortedRunMerger; +class SortedRunMergerLocalState; +class SortedRunMergerGlobalState; + +class TupleDataLayout; +class ColumnDataCollection; + //! Class that sorts the data, follows the PhysicalOperator interface class Sort { friend class SortLocalSinkState; friend class SortGlobalSinkState; + friend class SortLocalSourceState; friend class SortGlobalSourceState; + friend class SortedRun; + friend class SortedRunScanState; + + friend class SortedRunMerger; + friend class SortedRunMergerLocalState; + friend class SortedRunMergerGlobalState; + public: Sort(ClientContext &context, const vector &orders, const vector &input_types, vector projection_map, bool is_index_sort = false); @@ -45,7 +64,7 @@ class Sort { vector input_projection_map; vector output_projection_columns; - //! Whether to force an external sort + //! 
Whether to force an approximate sort bool is_index_sort; public: diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp index 8d8d86aca..68d7594be 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort_key.hpp @@ -45,7 +45,7 @@ struct SortKey; template struct SortKeyNoPayload { protected: - SortKeyNoPayload() = default; + SortKeyNoPayload() = default; // NOLINT friend SORT_KEY; public: @@ -63,7 +63,7 @@ struct SortKeyNoPayload { template struct SortKeyPayload { protected: - SortKeyPayload() = default; + SortKeyPayload() = default; // NOLINT friend SORT_KEY; public: @@ -93,7 +93,7 @@ inline bool SortKeyLessThan<1>(const uint64_t *const &lhs, const uint64_t *const template struct FixedSortKey : std::conditional, SortKeyNoPayload>::type { protected: - FixedSortKey() = default; + FixedSortKey() = default; // NOLINT friend SORT_KEY; public: @@ -102,7 +102,7 @@ struct FixedSortKey : std::conditional, So void ByteSwap() { auto &sort_key = static_cast(*this); for (idx_t i = 0; i < SORT_KEY::PARTS; i++) { - (&sort_key.part0)[i] = BSwap((&sort_key.part0)[i]); + (&sort_key.part0)[i] = BSwapIfLE((&sort_key.part0)[i]); } } @@ -163,7 +163,7 @@ struct FixedSortKey : std::conditional, So template struct VariableSortKey : std::conditional, SortKeyNoPayload>::type { protected: - VariableSortKey() = default; + VariableSortKey() = default; // NOLINT friend SORT_KEY; public: @@ -172,7 +172,7 @@ struct VariableSortKey : std::conditional, void ByteSwap() { auto &sort_key = static_cast(*this); for (idx_t i = 0; i < SORT_KEY::PARTS; i++) { - (&sort_key.part0)[i] = BSwap((&sort_key.part0)[i]); + (&sort_key.part0)[i] = BSwapIfLE((&sort_key.part0)[i]); } } diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort_strategy.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort_strategy.hpp new file mode 100644 index 000000000..75f89535a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sorting/sort_strategy.hpp @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sorting/sort_strategy.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/sorting/sort.hpp" + +namespace duckdb { + +class SortStrategy { +public: + using Types = vector; + using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; + + static unique_ptr Factory(ClientContext &context, const vector> &partition_bys, + const vector &order_bys, const Types &payload_types, + const vector> &partitions_stats, + idx_t estimated_cardinality, bool require_payload = false); + + explicit SortStrategy(const Types &input_types); + virtual ~SortStrategy() = default; + +public: + //===--------------------------------------------------------------------===// + // Sink Interface + //===--------------------------------------------------------------------===// + virtual unique_ptr GetLocalSinkState(ExecutionContext &context) const = 0; + virtual unique_ptr GetGlobalSinkState(ClientContext &client) const = 0; + virtual SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const = 0; + virtual SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const = 0; + virtual SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const = 0; + virtual 
ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, + const ProgressData source_progress) const = 0; + virtual void Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const; + +public: + //===--------------------------------------------------------------------===// + // Source Interface + //===--------------------------------------------------------------------===// + virtual unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const; + virtual unique_ptr GetGlobalSourceState(ClientContext &context, GlobalSinkState &sink) const = 0; + +public: + //===--------------------------------------------------------------------===// + // Non-Standard Interface + //===--------------------------------------------------------------------===// + virtual void SortColumnData(ExecutionContext &context, hash_t hash_bin, OperatorSinkFinalizeInput &finalize); + + virtual SourceResultType MaterializeColumnData(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const = 0; + virtual HashGroupPtr GetColumnData(idx_t hash_bin, OperatorSourceInput &source) const = 0; + + virtual SourceResultType MaterializeSortedRun(ExecutionContext &context, idx_t hash_bin, + OperatorSourceInput &source) const = 0; + virtual SortedRunPtr GetSortedRun(ClientContext &client, idx_t hash_bin, OperatorSourceInput &source) const = 0; + + // The chunk and row counts of the hash groups. + struct ChunkRow { + idx_t chunks = 0; + idx_t count = 0; + }; + using ChunkRows = vector; + virtual const ChunkRows &GetHashGroups(GlobalSourceState &global_state) const = 0; + +public: + //! The inserted data schema + Types payload_types; + //! Input columns in the sorted output + vector scan_ids; + // Key columns in the sorted output. Needed for prefix computations. + vector sort_ids; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp index fe0d67e32..a5714cf8f 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp @@ -9,18 +9,41 @@ #pragma once #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/execution/expression_executor.hpp" namespace duckdb { +class Sort; +class SortedRun; class BufferManager; class DataChunk; class TupleDataCollection; class TupleDataLayout; +class SortedRunScanState { +public: + SortedRunScanState(ClientContext &context, const Sort &sort); + +public: + void Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, DataChunk &chunk); + +private: + template + void TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk); + +private: + const Sort &sort; + ExpressionExecutor key_executor; + DataChunk key; + DataChunk decoded_key; + TupleDataScanState payload_state; + vector key_buffer; +}; + class SortedRun { public: - SortedRun(ClientContext &context, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort); + SortedRun(ClientContext &context, const Sort &sort, bool is_index_sort); unique_ptr CreateRunForMaterialization() const; ~SortedRun(); @@ -36,8 +59,13 @@ class SortedRun { //! Size of this sorted run idx_t SizeInBytes() const; +private: + mutex merger_global_state_lock; + unique_ptr merge_global_state; + public: ClientContext &context; + const Sort &sort; //! 
Key and payload collections (and associated append states) unique_ptr key_data; diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp index 21a56df83..fd894d698 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp @@ -9,10 +9,10 @@ #pragma once #include "duckdb/execution/physical_operator_states.hpp" -#include "duckdb/common/sorting/sort_projection_column.hpp" namespace duckdb { +class Sort; class TupleDataLayout; struct BoundOrderByNode; struct ProgressData; @@ -24,9 +24,7 @@ class SortedRunMerger { friend class SortedRunMergerGlobalState; public: - SortedRunMerger(const Expression &decode_sort_key, shared_ptr key_layout, - vector> &&sorted_runs, - const vector &output_projection_columns, idx_t partition_size, bool external, + SortedRunMerger(const Sort &sort, vector> &&sorted_runs, idx_t partition_size, bool external, bool is_index_sort); public: @@ -44,14 +42,12 @@ class SortedRunMerger { //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// - SourceResultType MaterializeMerge(ExecutionContext &context, OperatorSourceInput &input) const; - unique_ptr GetMaterialized(GlobalSourceState &global_state); + SourceResultType MaterializeSortedRun(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetSortedRun(GlobalSourceState &global_state); public: - const Expression &decode_sort_key; - shared_ptr key_layout; + const Sort &sort; vector> sorted_runs; - const vector &output_projection_columns; const idx_t total_count; const idx_t partition_size; diff --git a/src/duckdb/src/include/duckdb/common/string_map_set.hpp b/src/duckdb/src/include/duckdb/common/string_map_set.hpp index 00600c421..40bd51171 100644 --- a/src/duckdb/src/include/duckdb/common/string_map_set.hpp +++ b/src/duckdb/src/include/duckdb/common/string_map_set.hpp @@ -28,9 +28,26 @@ struct StringEquality { } }; +struct StringCIHash { + std::size_t operator()(const string_t &k) const { + return StringUtil::CIHash(k.GetData(), k.GetSize()); + } +}; + +struct StringCIEquality { + bool operator()(const string_t &a, const string_t &b) const { + return StringUtil::CIEquals(a.GetData(), a.GetSize(), b.GetData(), b.GetSize()); + } +}; + template using string_map_t = unordered_map; using string_set_t = unordered_set; +template +using case_insensitive_string_map_t = unordered_map; + +using case_insensitive_string_set_t = unordered_set; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp index 8c0c19bef..97d212f88 100644 --- a/src/duckdb/src/include/duckdb/common/string_util.hpp +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -217,6 +217,7 @@ class StringUtil { //! Case insensitive hash DUCKDB_API static uint64_t CIHash(const string &str); + DUCKDB_API static uint64_t CIHash(const char *str, idx_t size); //! 
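
[Editor's note: the new StringCIHash/StringCIEquality functors and case_insensitive_string_map_t alias above build a case-insensitive hash map by pairing a lowercasing hash with a lowercasing equality check. A minimal stand-alone sketch of the same idea, assuming plain std::string and a simple FNV-1a hash instead of DuckDB's string_t and StringUtil helpers.]

#include <cctype>
#include <cstdint>
#include <string>
#include <unordered_map>

struct CIHash {
	size_t operator()(const std::string &s) const {
		uint64_t h = 1469598103934665603ULL; // FNV-1a offset basis
		for (unsigned char c : s) {
			h ^= static_cast<uint64_t>(std::tolower(c));
			h *= 1099511628211ULL; // FNV-1a prime
		}
		return static_cast<size_t>(h);
	}
};

struct CIEquals {
	bool operator()(const std::string &a, const std::string &b) const {
		if (a.size() != b.size()) {
			return false;
		}
		for (size_t i = 0; i < a.size(); i++) {
			if (std::tolower(static_cast<unsigned char>(a[i])) !=
			    std::tolower(static_cast<unsigned char>(b[i]))) {
				return false;
			}
		}
		return true;
	}
};

// "SELECT", "select" and "Select" all map to the same entry.
using ci_map = std::unordered_map<std::string, int, CIHash, CIEquals>;
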
Case insensitive equals DUCKDB_API static bool CIEquals(const string &l1, const string &l2); @@ -299,6 +300,17 @@ class StringUtil { } return strcmp(s1, s2) == 0; } + static bool Equals(const string &s1, const char *s2) { + return Equals(s1.c_str(), s2); + } + static bool Equals(const char *s1, const string &s2) { + return Equals(s1, s2.c_str()); + } + static bool Equals(const string &s1, const string &s2) { + return s1 == s2; + } + static bool Equals(const string_t &s1, const char *s2); + static bool Equals(const char *s1, const string_t &s2); //! JSON method that parses a { string: value } JSON blob //! NOTE: this method is not efficient @@ -318,6 +330,8 @@ class StringUtil { //! Transforms an complex JSON to a JSON string DUCKDB_API static string ToComplexJSONMap(const ComplexJSON &complex_json); + DUCKDB_API static string ValidateJSON(const char *data, const idx_t &len); + DUCKDB_API static string GetFileName(const string &file_path); DUCKDB_API static string GetFileExtension(const string &file_name); DUCKDB_API static string GetFileStem(const string &file_name); diff --git a/src/duckdb/src/include/duckdb/common/thread.hpp b/src/duckdb/src/include/duckdb/common/thread.hpp index 7540dfc50..6c15a662b 100644 --- a/src/duckdb/src/include/duckdb/common/thread.hpp +++ b/src/duckdb/src/include/duckdb/common/thread.hpp @@ -8,8 +8,21 @@ #pragma once +#ifndef DUCKDB_NO_THREADS #include +#include "duckdb/common/typedefs.hpp" namespace duckdb { using std::thread; + } + +#endif + +namespace duckdb { + +struct ThreadUtil { + static void SleepMs(idx_t ms); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp b/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp new file mode 100644 index 000000000..63d87e77e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/tree_renderer/mermaid_tree_renderer.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/mermaid_tree_renderer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/profiling_node.hpp" +#include "duckdb/common/tree_renderer.hpp" +#include "duckdb/common/render_tree.hpp" + +namespace duckdb { +class LogicalOperator; +class PhysicalOperator; +class Pipeline; +struct PipelineRenderNode; + +class MermaidTreeRenderer : public TreeRenderer { +public: + explicit MermaidTreeRenderer() { + } + ~MermaidTreeRenderer() override { + } + +public: + string ToString(const LogicalOperator &op); + string ToString(const PhysicalOperator &op); + string ToString(const ProfilingNode &op); + string ToString(const Pipeline &op); + + void Render(const LogicalOperator &op, std::ostream &ss); + void Render(const PhysicalOperator &op, std::ostream &ss); + void Render(const ProfilingNode &op, std::ostream &ss) override; + void Render(const Pipeline &op, std::ostream &ss); + + void ToStreamInternal(RenderTree &root, std::ostream &ss) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/type_util.hpp b/src/duckdb/src/include/duckdb/common/type_util.hpp index 40a3eb872..8c0e7ddc9 100644 --- a/src/duckdb/src/include/duckdb/common/type_util.hpp +++ b/src/duckdb/src/include/duckdb/common/type_util.hpp @@ -22,60 +22,62 @@ struct bignum_t; //! 
Returns the PhysicalType for the given type template PhysicalType GetTypeId() { - if (std::is_same()) { + using TYPE = typename std::remove_cv::type; + + if (std::is_same()) { return PhysicalType::BOOL; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT8; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT16; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT8; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT16; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT64; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::UINT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT128; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::UINT128; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT32; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INT64; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::FLOAT; - } else if (std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same()) { return PhysicalType::DOUBLE; - } else if (std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + } else if (std::is_same() || std::is_same() || std::is_same() || + std::is_same()) { return PhysicalType::VARCHAR; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::INTERVAL; - } else if (std::is_same()) { + } else if (std::is_same()) { return PhysicalType::LIST; - } else if (std::is_pointer() || std::is_same()) { + } else if (std::is_pointer() || std::is_same()) { if (sizeof(uintptr_t) == sizeof(uint32_t)) { return PhysicalType::UINT32; } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { @@ -90,10 +92,12 @@ PhysicalType GetTypeId() { template bool StorageTypeCompatible(PhysicalType type) { - if (std::is_same()) { + using TYPE = typename std::remove_cv::type; + + if (std::is_same()) { return type == PhysicalType::INT8 || type == PhysicalType::BOOL; } - if (std::is_same()) { + if (std::is_same()) { return type == PhysicalType::UINT8 || type == PhysicalType::BOOL; } return type == GetTypeId(); @@ -101,8 +105,10 @@ bool StorageTypeCompatible(PhysicalType type) { template bool TypeIsNumber() { - return std::is_integral() || std::is_floating_point() || std::is_same() || - std::is_same(); + using TYPE = typename 
std::remove_cv::type; + + return std::is_integral() || std::is_floating_point() || std::is_same() || + std::is_same(); } template diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp index 0f7ddbb2d..6d85ce2de 100644 --- a/src/duckdb/src/include/duckdb/common/types.hpp +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -230,6 +230,8 @@ enum class LogicalTypeId : uint8_t { VALIDITY = 53, UUID = 54, + GEOMETRY = 60, + STRUCT = 100, LIST = 101, MAP = 102, @@ -430,6 +432,7 @@ struct LogicalType { DUCKDB_API static LogicalType UNION(child_list_t members); // NOLINT DUCKDB_API static LogicalType ARRAY(const LogicalType &child, optional_idx index); // NOLINT DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType GEOMETRY(); // NOLINT // ANY but with special rules (default is LogicalType::ANY, 5) DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT DUCKDB_API static LogicalType TEMPLATE(const string &name); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp index 8a5cc9e19..4f1d8f2b0 100644 --- a/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/batched_data_collection.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/common/types/batched_chunk_collection.hpp +// duckdb/common/types/batched_data_collection.hpp // // //===----------------------------------------------------------------------===// @@ -10,8 +10,10 @@ #include "duckdb/common/map.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/main/query_parameters.hpp" namespace duckdb { + class BufferManager; class ClientContext; @@ -32,9 +34,16 @@ struct BatchedChunkScanState { //! Scans over a BatchedDataCollection are ordered by batch index class BatchedDataCollection { public: - DUCKDB_API BatchedDataCollection(ClientContext &context, vector types, bool buffer_managed = false); - DUCKDB_API BatchedDataCollection(ClientContext &context, vector types, batch_map_t batches, - bool buffer_managed = false); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, + ColumnDataAllocatorType allocator_type = ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, QueryResultMemoryType memory_type); + DUCKDB_API + BatchedDataCollection(ClientContext &context, vector types, batch_map_t batches, + ColumnDataAllocatorType allocator_type = ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Appends a datachunk with the given batch index to the batched collection DUCKDB_API void Append(DataChunk &input, idx_t batch_index); @@ -79,6 +88,8 @@ class BatchedDataCollection { DUCKDB_API void Print() const; private: + unique_ptr CreateCollection() const; + struct CachedCollection { idx_t batch_index = DConstants::INVALID_INDEX; ColumnDataCollection *collection = nullptr; @@ -87,7 +98,8 @@ class BatchedDataCollection { ClientContext &context; vector types; - bool buffer_managed; + ColumnDataAllocatorType allocator_type; + ColumnDataCollectionLifetime lifetime; //! 
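
[Editor's note: the GetTypeId/StorageTypeCompatible/TypeIsNumber changes above strip cv-qualifiers before the type comparisons, so that e.g. a const-qualified template argument resolves to the same PhysicalType as the plain type. A tiny self-contained illustration of why remove_cv is needed, using std::is_same directly.]

#include <cstdint>
#include <type_traits>

template <class T>
constexpr bool IsInt32() {
	using TYPE = typename std::remove_cv<T>::type;
	return std::is_same<TYPE, int32_t>::value;
}

static_assert(IsInt32<int32_t>(), "plain type matches");
static_assert(IsInt32<const int32_t>(), "const-qualified type matches after remove_cv");
static_assert(!IsInt32<int64_t>(), "a different type still does not match");
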
The data of the batched chunk collection - a set of batch_index -> ColumnDataCollection pointers map> data; //! The last batch collection that was inserted into diff --git a/src/duckdb/src/include/duckdb/common/types/bit.hpp b/src/duckdb/src/include/duckdb/common/types/bit.hpp index cbf599139..c1d17095d 100644 --- a/src/duckdb/src/include/duckdb/common/types/bit.hpp +++ b/src/duckdb/src/include/duckdb/common/types/bit.hpp @@ -101,8 +101,9 @@ template void Bit::NumericToBit(T numeric, bitstring_t &output_str) { D_ASSERT(output_str.GetSize() >= sizeof(T) + 1); + auto le_numeric = BSwapIfBE(numeric); auto output = output_str.GetDataWriteable(); - auto data = const_data_ptr_cast(&numeric); + auto data = const_data_ptr_cast(&le_numeric); *output = 0; // set padding to 0 ++output; @@ -141,6 +142,7 @@ void Bit::BitToNumeric(bitstring_t bit, T &output_num) { for (idx_t idx = padded_byte_idx + 1; idx < sizeof(T); ++idx) { output[sizeof(T) - 1 - idx] = data[1 + idx - padded_byte_idx]; } + output_num = BSwapIfBE(output_num); } } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp index 564ca5c09..6f2c8a1c1 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -17,21 +18,31 @@ struct VectorMetaData; struct SwizzleMetaData; struct BlockMetaData { - //! The underlying block handle - shared_ptr handle; +public: //! How much space is currently used within the block uint32_t size; //! How much space is available in the block uint32_t capacity; +private: + //! The underlying block handle + shared_ptr handle; + //! Weak pointer to underlying block handle (if ColumnDataCollectionLifetime::DATABASE_INSTANCE) + weak_ptr weak_handle; + +public: + shared_ptr GetHandle() const; + void SetHandle(ManagedResultSet &managed_result_set, shared_ptr handle); uint32_t Capacity(); }; class ColumnDataAllocator { public: explicit ColumnDataAllocator(Allocator &allocator); - explicit ColumnDataAllocator(BufferManager &buffer_manager); - ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type); + explicit ColumnDataAllocator(BufferManager &buffer_manager, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); + ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); ColumnDataAllocator(ColumnDataAllocator &allocator); ~ColumnDataAllocator(); @@ -81,6 +92,8 @@ class ColumnDataAllocator { //! Prevents the block with the given id from being added to the eviction queue void SetDestroyBufferUponUnpin(uint32_t block_id); + //! Gets a shared pointer to the database instance if ColumnDataCollectionLifetime::DATABASE_INSTANCE + shared_ptr GetDatabase() const; private: void AllocateEmptyBlock(idx_t size); @@ -116,6 +129,8 @@ class ColumnDataAllocator { idx_t allocated_size = 0; //! Partition index (optional, if partitioned) optional_idx partition_index; + //! 
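
[Editor's note: the BSwapIfBE calls added to Bit::NumericToBit/BitToNumeric above appear to normalize the in-memory byte layout of the numeric to little-endian before the byte-wise copy, so the resulting bitstring is the same on big- and little-endian hosts. A rough sketch of such a swap-only-on-big-endian helper for 32-bit values; DuckDB's own helpers cover all integer widths.]

#include <cstdint>

static uint32_t BSwap32(uint32_t v) {
	return (v >> 24) | ((v >> 8) & 0x0000FF00u) | ((v << 8) & 0x00FF0000u) | (v << 24);
}

static bool IsBigEndianHost() {
	const uint32_t probe = 1;
	// On a big-endian host the most significant byte comes first, so byte 0 is 0.
	return *reinterpret_cast<const unsigned char *>(&probe) == 0;
}

static uint32_t BSwapIfBigEndian(uint32_t v) {
	return IsBigEndianHost() ? BSwap32(v) : v;
}
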
Lifetime management for this allocator + ManagedResultSet managed_result_set; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp index f02d49001..6cb1d7bdd 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp @@ -8,10 +8,10 @@ #pragma once -#include "duckdb/common/pair.hpp" #include "duckdb/common/types/column/column_data_collection_iterators.hpp" namespace duckdb { + class BufferManager; class BlockHandle; class ClientContext; @@ -30,10 +30,14 @@ class ColumnDataCollection { //! Constructs an empty (but valid) in-memory column data collection from an allocator DUCKDB_API explicit ColumnDataCollection(Allocator &allocator); //! Constructs a buffer-managed column data collection - DUCKDB_API ColumnDataCollection(BufferManager &buffer_manager, vector types); + DUCKDB_API + ColumnDataCollection(BufferManager &buffer_manager, vector types, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Constructs either an in-memory or a buffer-managed column data collection - DUCKDB_API ColumnDataCollection(ClientContext &context, vector types, - ColumnDataAllocatorType type = ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); + DUCKDB_API + ColumnDataCollection(ClientContext &context, vector types, + ColumnDataAllocatorType type = ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR, + ColumnDataCollectionLifetime lifetime = ColumnDataCollectionLifetime::REGULAR); //! Creates a column data collection that inherits the blocks to write to. This allows blocks to be shared //! between multiple column data collections and prevents wasting space. //! Note that after one CDC inherits blocks from another, the other @@ -78,6 +82,7 @@ class ColumnDataCollection { //! Initializes a chunk with the correct types that can be used to call Scan DUCKDB_API void InitializeScanChunk(DataChunk &chunk) const; + DUCKDB_API void InitializeScanChunk(Allocator &allocator, DataChunk &chunk) const; //! Initializes a chunk with the correct types for a given scan state DUCKDB_API void InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const; //! Initializes a Scan state for scanning all columns @@ -161,6 +166,8 @@ class ColumnDataCollection { vector> GetHeapReferences(); //! Get the allocator type of this ColumnDataCollection ColumnDataAllocatorType GetAllocatorType() const; + //! Get the buffer manager of the allocator + BufferManager &GetBufferManager() const; //! Get a vector of the segments in this ColumnDataCollection const vector> &GetSegments() const; @@ -194,7 +201,9 @@ class ColumnDataCollection { //! 
The ColumnDataRowCollection represents a set of materialized rows, as obtained from the ColumnDataCollection class ColumnDataRowCollection { public: - DUCKDB_API explicit ColumnDataRowCollection(const ColumnDataCollection &collection); + DUCKDB_API explicit ColumnDataRowCollection( + const ColumnDataCollection &collection, + ColumnDataScanProperties properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY); public: DUCKDB_API Value GetValue(idx_t column, idx_t index) const; diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp index b84b81d47..ff42eadf2 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_iterators.hpp @@ -63,7 +63,9 @@ class ColumnDataRowIterationHelper { class ColumnDataRowIterator { public: - DUCKDB_API explicit ColumnDataRowIterator(const ColumnDataCollection *collection_p); + DUCKDB_API explicit ColumnDataRowIterator( + const ColumnDataCollection *collection_p, + ColumnDataScanProperties properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY); const ColumnDataCollection *collection; ColumnDataScanState scan_state; diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp index c809520c6..d544db851 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_scan_states.hpp @@ -35,6 +35,14 @@ enum class ColumnDataScanProperties : uint8_t { DISALLOW_ZERO_COPY }; +enum class ColumnDataCollectionLifetime { + //! Regular lifetime management + REGULAR, + //! Accessing will throw an error after the DB closes + //! Optional for ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR only + THROW_ERROR_AFTER_DATABASE_CLOSES, +}; + struct ChunkManagementState { unordered_map handles; ColumnDataScanProperties properties = ColumnDataScanProperties::INVALID; @@ -46,6 +54,9 @@ struct ColumnDataAppendState { }; struct ColumnDataScanState { + //! 
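
[Editor's note: the new ColumnDataCollectionLifetime::THROW_ERROR_AFTER_DATABASE_CLOSES enum value above, together with the weak_handle added to BlockMetaData earlier in this diff, lets a result set hold only a weak reference to its block handles: once the owner goes away, access fails loudly instead of keeping buffers alive. A minimal sketch of that weak-handle pattern with illustrative names.]

#include <memory>
#include <stdexcept>

struct Block {
	// pinned buffer data would live here
};

struct WeakBlockRef {
	std::weak_ptr<Block> weak_handle;

	std::shared_ptr<Block> GetHandle() const {
		auto handle = weak_handle.lock();
		if (!handle) {
			// the owning instance has been destroyed
			throw std::runtime_error("result set outlived the database instance");
		}
		return handle;
	}
};
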
Database instance if scanning ColumnDataCollectionLifetime::DATABASE_INSTANCE + shared_ptr db; + ChunkManagementState current_chunk_state; idx_t segment_index; idx_t chunk_index; diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp new file mode 100644 index 000000000..1fcc9d7a1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -0,0 +1,224 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/geometry.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/pair.hpp" +#include +#include + +namespace duckdb { + +struct GeometryStatsData; + +enum class GeometryType : uint8_t { + INVALID = 0, + POINT = 1, + LINESTRING = 2, + POLYGON = 3, + MULTIPOINT = 4, + MULTILINESTRING = 5, + MULTIPOLYGON = 6, + GEOMETRYCOLLECTION = 7, +}; + +enum class VertexType : uint8_t { XY = 0, XYZ = 1, XYM = 2, XYZM = 3 }; + +struct VertexXY { + static constexpr auto TYPE = VertexType::XY; + static constexpr auto HAS_Z = false; + static constexpr auto HAS_M = false; + + double x; + double y; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y); + } +}; + +struct VertexXYZ { + static constexpr auto TYPE = VertexType::XYZ; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = false; + + double x; + double y; + double z; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z); + } +}; +struct VertexXYM { + static constexpr auto TYPE = VertexType::XYM; + static constexpr auto HAS_M = true; + static constexpr auto HAS_Z = false; + + double x; + double y; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(m); + } +}; + +struct VertexXYZM { + static constexpr auto TYPE = VertexType::XYZM; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = true; + + double x; + double y; + double z; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z) && std::isnan(m); + } +}; + +class GeometryExtent { +public: + static constexpr auto UNKNOWN_MIN = -std::numeric_limits::infinity(); + static constexpr auto UNKNOWN_MAX = +std::numeric_limits::infinity(); + + static constexpr auto EMPTY_MIN = +std::numeric_limits::infinity(); + static constexpr auto EMPTY_MAX = -std::numeric_limits::infinity(); + + // "Unknown" extent means we don't know the bounding box. + // Merging with an unknown extent results in an unknown extent. + // Everything intersects with an unknown extent. + static GeometryExtent Unknown() { + return GeometryExtent {UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, + UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX}; + } + + // "Empty" extent means the smallest possible bounding box. + // Merging with an empty extent has no effect. + // Nothing intersects with an empty extent. + static GeometryExtent Empty() { + return GeometryExtent {EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX}; + } + + // Does this extent have any X/Y values set? + // In other words, is the range of the x/y axes not empty and not unknown? + bool HasXY() const { + return std::isfinite(x_min) && std::isfinite(y_min) && std::isfinite(x_max) && std::isfinite(y_max); + } + // Does this extent have any Z values set? 
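[Editor's note, not part of the patch] An illustrative sketch of how the +/- infinity sentinels behind GeometryExtent::Empty() and GeometryExtent::Unknown() are meant to behave: the first Extend() replaces the empty sentinels, Merge() with an unknown extent stays unknown, and nothing intersects an empty extent. The Extent type below is a simplified XY-only stand-in, not the DuckDB class.

#include <algorithm>
#include <cstdio>
#include <limits>

struct Extent {
	// "empty" = smallest possible box: mins start at +inf, maxes at -inf
	double x_min = std::numeric_limits<double>::infinity();
	double y_min = std::numeric_limits<double>::infinity();
	double x_max = -std::numeric_limits<double>::infinity();
	double y_max = -std::numeric_limits<double>::infinity();

	void Extend(double x, double y) {
		// min/max against the sentinels: the first vertex always replaces them
		x_min = std::min(x_min, x);
		y_min = std::min(y_min, y);
		x_max = std::max(x_max, x);
		y_max = std::max(y_max, y);
	}
	void Merge(const Extent &other) {
		// merging with an empty extent is a no-op; merging with an "unknown"
		// extent (min = -inf, max = +inf) makes this extent unknown as well
		x_min = std::min(x_min, other.x_min);
		y_min = std::min(y_min, other.y_min);
		x_max = std::max(x_max, other.x_max);
		y_max = std::max(y_max, other.y_max);
	}
	bool IntersectsXY(const Extent &other) const {
		// an empty extent can never satisfy these comparisons,
		// an unknown extent always does
		return !(x_min > other.x_max || x_max < other.x_min || y_min > other.y_max || y_max < other.y_min);
	}
};

int main() {
	Extent a, b; // both start empty
	std::printf("%d\n", a.IntersectsXY(b)); // 0: nothing intersects an empty extent
	a.Extend(1.0, 2.0);
	a.Extend(3.0, -1.0); // a now covers [1,3] x [-1,2]
	b.Extend(2.0, 0.0);
	std::printf("%d\n", a.IntersectsXY(b)); // 1: point (2,0) lies inside a
	return 0;
}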
+ // In other words, is the range of the Z-axis not empty and not unknown? + bool HasZ() const { + return std::isfinite(z_min) && std::isfinite(z_max); + } + // Does this extent have any M values set? + // In other words, is the range of the M-axis not empty and not unknown? + bool HasM() const { + return std::isfinite(m_min) && std::isfinite(m_max); + } + + void Extend(const VertexXY &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + } + + void Extend(const VertexXYZ &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + } + + void Extend(const VertexXYM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Extend(const VertexXYZM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Merge(const GeometryExtent &other) { + x_min = MinValue(x_min, other.x_min); + y_min = MinValue(y_min, other.y_min); + z_min = MinValue(z_min, other.z_min); + m_min = MinValue(m_min, other.m_min); + + x_max = MaxValue(x_max, other.x_max); + y_max = MaxValue(y_max, other.y_max); + z_max = MaxValue(z_max, other.z_max); + m_max = MaxValue(m_max, other.m_max); + } + + bool IntersectsXY(const GeometryExtent &other) const { + return !(x_min > other.x_max || x_max < other.x_min || y_min > other.y_max || y_max < other.y_min); + } + + bool IntersectsXYZM(const GeometryExtent &other) const { + return !(x_min > other.x_max || x_max < other.x_min || y_min > other.y_max || y_max < other.y_min || + z_min > other.z_max || z_max < other.z_min || m_min > other.m_max || m_max < other.m_min); + } + + bool ContainsXY(const GeometryExtent &other) const { + return x_min <= other.x_min && x_max >= other.x_max && y_min <= other.y_min && y_max >= other.y_max; + } + + double x_min; + double y_min; + double z_min; + double m_min; + + double x_max; + double y_max; + double z_max; + double m_max; +}; + +class Geometry { +public: + static constexpr idx_t MAX_RECURSION_DEPTH = 16; + + //! Convert from WKT + DUCKDB_API static bool FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict); + + //! Convert to WKT + DUCKDB_API static string_t ToString(Vector &result, const string_t &geom); + + //! Convert from WKB + DUCKDB_API static bool FromBinary(const string_t &wkb, string_t &result, Vector &result_vector, bool strict); + DUCKDB_API static bool FromBinary(Vector &source, Vector &result, idx_t count, bool strict); + + //! Convert to WKB + DUCKDB_API static void ToBinary(Vector &source, Vector &result, idx_t count); + + //! Get the geometry type and vertex type from the WKB + DUCKDB_API static pair GetType(const string_t &wkb); + + //! 
Update the bounding box, return number of vertices processed + DUCKDB_API static uint32_t GetExtent(const string_t &wkb, GeometryExtent &extent); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp index 3720bf844..9fa5d447b 100644 --- a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp @@ -129,38 +129,38 @@ class Hugeint { static int Sign(hugeint_t n); static hugeint_t Abs(hugeint_t n); // comparison operators - static bool Equals(hugeint_t lhs, hugeint_t rhs) { + static bool Equals(const hugeint_t &lhs, const hugeint_t &rhs) { bool lower_equals = lhs.lower == rhs.lower; bool upper_equals = lhs.upper == rhs.upper; return lower_equals && upper_equals; } - static bool NotEquals(hugeint_t lhs, hugeint_t rhs) { + static bool NotEquals(const hugeint_t &lhs, const hugeint_t &rhs) { return !Equals(lhs, rhs); } - static bool GreaterThan(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger = lhs.lower > rhs.lower; return upper_bigger || (upper_equal && lower_bigger); } - static bool GreaterThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger_equals = lhs.lower >= rhs.lower; return upper_bigger || (upper_equal && lower_bigger_equals); } - static bool LessThan(hugeint_t lhs, hugeint_t rhs) { + static bool LessThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller = lhs.lower < rhs.lower; return upper_smaller || (upper_equal && lower_smaller); } - static bool LessThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool LessThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller_equals = lhs.lower <= rhs.lower; diff --git a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp index c29b094a8..8babd751c 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp @@ -23,64 +23,96 @@ enum class BlockIteratorStateType : int8_t { EXTERNAL, }; -BlockIteratorStateType GetBlockIteratorStateType(const bool &external); +template +class BlockIteratorStateBase { +protected: + friend BLOCK_ITERATOR_STATE; -//! State for iterating over blocks of an in-memory TupleDataCollection -//! 
Multiple iterators can share the same state, everything is const -class InMemoryBlockIteratorState { -public: - explicit InMemoryBlockIteratorState(const TupleDataCollection &key_data); - -public: - template - T &GetValueAtIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { - D_ASSERT(GetIndex(block_idx, tuple_idx) < tuple_count); - return reinterpret_cast(block_ptrs[block_idx])[tuple_idx]; +private: + explicit BlockIteratorStateBase(const idx_t tuple_count_p) : tuple_count(tuple_count_p) { } - template - T &GetValueAtIndex(const idx_t &n) const { - const auto quotient = fast_mod.Div(n); - return GetValueAtIndex(quotient, fast_mod.Mod(n, quotient)); +public: + idx_t GetDivisor() const { + const auto &state = static_cast(*this); + return state.GetDivisor(); } - void RandomAccess(idx_t &block_idx, idx_t &tuple_idx, const idx_t &index) const { - block_idx = fast_mod.Div(index); - tuple_idx = fast_mod.Mod(index, block_idx); + void RandomAccess(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &index) const { + const auto &state = static_cast(*this); + state.RandomAccessInternal(block_or_chunk_idx, tuple_idx, index); } - void Add(idx_t &block_idx, idx_t &tuple_idx, const idx_t &value) const { + void Add(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &value) const { tuple_idx += value; - if (tuple_idx >= fast_mod.GetDivisor()) { - const auto div = fast_mod.Div(tuple_idx); - tuple_idx -= div * fast_mod.GetDivisor(); - block_idx += div; + if (tuple_idx >= GetDivisor()) { + RandomAccess(block_or_chunk_idx, tuple_idx, GetIndex(block_or_chunk_idx, tuple_idx)); } } - void Subtract(idx_t &block_idx, idx_t &tuple_idx, const idx_t &value) const { + void Subtract(idx_t &block_or_chunk_idx, idx_t &tuple_idx, const idx_t &value) const { tuple_idx -= value; - if (tuple_idx >= fast_mod.GetDivisor()) { - const auto div = fast_mod.Div(-tuple_idx); - tuple_idx += (div + 1) * fast_mod.GetDivisor(); - block_idx -= div + 1; + if (tuple_idx >= GetDivisor()) { + RandomAccess(block_or_chunk_idx, tuple_idx, GetIndex(block_or_chunk_idx, tuple_idx)); } } - void Increment(idx_t &block_idx, idx_t &tuple_idx) const { - const auto passed_boundary = ++tuple_idx == fast_mod.GetDivisor(); - block_idx += passed_boundary; - tuple_idx *= !passed_boundary; + void Increment(idx_t &block_or_chunk_idx, idx_t &tuple_idx) const { + const auto crossed_boundary = ++tuple_idx == GetDivisor(); + block_or_chunk_idx += crossed_boundary; + tuple_idx *= !crossed_boundary; } - void Decrement(idx_t &block_idx, idx_t &tuple_idx) const { + void Decrement(idx_t &block_or_chunk_idx, idx_t &tuple_idx) const { const auto crossed_boundary = tuple_idx-- == 0; - block_idx -= crossed_boundary; - tuple_idx += crossed_boundary * fast_mod.GetDivisor(); + block_or_chunk_idx -= crossed_boundary; + tuple_idx += crossed_boundary * GetDivisor(); + } + + idx_t GetIndex(const idx_t &block_or_chunk_idx, const idx_t &tuple_idx) const { + return block_or_chunk_idx * GetDivisor() + tuple_idx; + } + +protected: + const idx_t tuple_count; +}; + +template +class BlockIteratorState; + +//! State for iterating over blocks of an in-memory TupleDataCollection +//! 
Multiple iterators can share the same state, everything is const +template <> +class BlockIteratorState + : public BlockIteratorStateBase> { +public: + explicit BlockIteratorState(const TupleDataCollection &key_data) + : BlockIteratorStateBase(key_data.Count()), block_ptrs(ConvertBlockPointers(key_data.GetRowBlockPointers())), + fast_mod(key_data.TuplesPerBlock()) { + } + +public: + idx_t GetDivisor() const { + return fast_mod.GetDivisor(); + } + + void RandomAccessInternal(idx_t &block_idx, idx_t &tuple_idx, const idx_t &index) const { + block_idx = fast_mod.Div(index); + tuple_idx = fast_mod.Mod(index, block_idx); + } + + template + T &GetValueAtIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { + D_ASSERT(GetIndex(block_idx, tuple_idx) < tuple_count); + return reinterpret_cast(block_ptrs[block_idx])[tuple_idx]; } - idx_t GetIndex(const idx_t &block_idx, const idx_t &tuple_idx) const { - return block_idx * fast_mod.GetDivisor() + tuple_idx; + template + T &GetValueAtIndex(const idx_t &index) const { + idx_t block_idx; + idx_t tuple_idx; + RandomAccess(block_idx, tuple_idx, index); + return GetValueAtIndex(block_idx, tuple_idx); } void SetKeepPinned(const bool &) { @@ -92,72 +124,63 @@ class InMemoryBlockIteratorState { } private: - static unsafe_vector ConvertBlockPointers(const vector &block_ptrs); + static unsafe_vector ConvertBlockPointers(const vector &block_ptrs) { + unsafe_vector converted_block_ptrs; + converted_block_ptrs.reserve(block_ptrs.size()); + for (const auto &block_ptr : block_ptrs) { + converted_block_ptrs.emplace_back(block_ptr); + } + return converted_block_ptrs; + } private: const unsafe_vector block_ptrs; const FastMod fast_mod; - const idx_t tuple_count; }; +using InMemoryBlockIteratorState = BlockIteratorState; + //! State for iterating over blocks of an external (larger-than-memory) TupleDataCollection //! 
This state cannot be shared by multiple iterators, it is stateful -class ExternalBlockIteratorState { -public: - explicit ExternalBlockIteratorState(TupleDataCollection &key_data, optional_ptr payload_data); - +template <> +class BlockIteratorState + : public BlockIteratorStateBase> { public: - template - T &GetValueAtIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { - if (chunk_idx != current_chunk_idx) { - InitializeChunk(chunk_idx); + explicit BlockIteratorState(TupleDataCollection &key_data_p, optional_ptr payload_data_p) + : BlockIteratorStateBase(key_data_p.Count()), current_chunk_idx(DConstants::INVALID_INDEX), + key_data(key_data_p), key_ptrs(FlatVector::GetData(key_scan_state.chunk_state.row_locations)), + payload_data(payload_data_p), keep_pinned(false), pin_payload(false) { + key_data.InitializeScan(key_scan_state); + if (payload_data) { + payload_data->InitializeScan(payload_scan_state); } - return *reinterpret_cast(key_ptrs)[tuple_idx]; } - template - T &GetValueAtIndex(const idx_t &n) { - D_ASSERT(n < tuple_count); - return GetValueAtIndex(n / STANDARD_VECTOR_SIZE, n % STANDARD_VECTOR_SIZE); +public: + static constexpr idx_t GetDivisor() { + return STANDARD_VECTOR_SIZE; } - static void RandomAccess(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &index) { + static void RandomAccessInternal(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &index) { chunk_idx = index / STANDARD_VECTOR_SIZE; tuple_idx = index % STANDARD_VECTOR_SIZE; } - static void Add(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &value) { - tuple_idx += value; - if (tuple_idx >= STANDARD_VECTOR_SIZE) { - const auto div = tuple_idx / STANDARD_VECTOR_SIZE; - tuple_idx -= div * STANDARD_VECTOR_SIZE; - chunk_idx += div; - } - } - - static void Subtract(idx_t &chunk_idx, idx_t &tuple_idx, const idx_t &value) { - tuple_idx -= value; - if (tuple_idx >= STANDARD_VECTOR_SIZE) { - const auto div = -tuple_idx / STANDARD_VECTOR_SIZE; - tuple_idx += (div + 1) * STANDARD_VECTOR_SIZE; - chunk_idx -= div + 1; + template + T &GetValueAtIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { + D_ASSERT(GetIndex(chunk_idx, tuple_idx) < tuple_count); + if (chunk_idx != current_chunk_idx) { + InitializeChunk(chunk_idx); } + return *reinterpret_cast(key_ptrs)[tuple_idx]; } - static void Increment(idx_t &chunk_idx, idx_t &tuple_idx) { - const auto passed_boundary = ++tuple_idx == STANDARD_VECTOR_SIZE; - chunk_idx += passed_boundary; - tuple_idx *= !passed_boundary; - } - - static void Decrement(idx_t &chunk_idx, idx_t &tuple_idx) { - const auto crossed_boundary = tuple_idx-- == 0; - chunk_idx -= crossed_boundary; - tuple_idx += crossed_boundary * static_cast(STANDARD_VECTOR_SIZE); - } - - static idx_t GetIndex(const idx_t &chunk_idx, const idx_t &tuple_idx) { - return chunk_idx * STANDARD_VECTOR_SIZE + tuple_idx; + template + T &GetValueAtIndex(const idx_t &index) { + idx_t chunk_idx; + idx_t tuple_idx; + RandomAccess(chunk_idx, tuple_idx, index); + return GetValueAtIndex(chunk_idx, tuple_idx); } void SetKeepPinned(const bool &enable) { @@ -183,25 +206,18 @@ class ExternalBlockIteratorState { key_scan_state.pin_state.row_handles.acquire_handles(pins); key_scan_state.pin_state.heap_handles.acquire_handles(pins); } - key_data.FetchChunk(key_scan_state, 0, chunk_idx, false); + key_data.FetchChunk(key_scan_state, chunk_idx, false); if (pin_payload && payload_data) { if (keep_pinned) { payload_scan_state.pin_state.row_handles.acquire_handles(pins); payload_scan_state.pin_state.heap_handles.acquire_handles(pins); } - const auto 
chunk_count = payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); - const auto sort_keys = reinterpret_cast(key_ptrs); - payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); - const auto payload_ptrs = FlatVector::GetData(payload_scan_state.chunk_state.row_locations); - for (idx_t i = 0; i < chunk_count; i++) { - sort_keys[i]->SetPayload(payload_ptrs[i]); - D_ASSERT(GetValueAtIndex(chunk_idx, i).GetPayload() == payload_ptrs[i]); - } + SortKeyPayloadState skp_state {key_scan_state.chunk_state, key_data.GetLayout().GetSortKeyType()}; + payload_data->FetchChunk(payload_scan_state, chunk_idx, false, &skp_state); } } private: - const idx_t tuple_count; idx_t current_chunk_idx; TupleDataCollection &key_data; @@ -216,13 +232,7 @@ class ExternalBlockIteratorState { vector pins; }; -//! Utility so we can get the state using the type -template -using BlockIteratorState = typename std::conditional< - T == BlockIteratorStateType::IN_MEMORY, InMemoryBlockIteratorState, - typename std::conditional::type>::type; +using ExternalBlockIteratorState = BlockIteratorState; //! Iterator for data spread out over multiple blocks template @@ -305,16 +315,16 @@ class block_iterator_t { // NOLINT: match stl case return *this; } block_iterator_t operator+(const difference_type &n) const { - idx_t new_block_idx = block_or_chunk_idx; + idx_t new_block_or_chunk_idx = block_or_chunk_idx; idx_t new_tuple_idx = tuple_idx; - state->Add(new_block_idx, new_tuple_idx, n); - return block_iterator_t(*state, new_block_idx, new_tuple_idx); + state->Add(new_block_or_chunk_idx, new_tuple_idx, n); + return block_iterator_t(*state, new_block_or_chunk_idx, new_tuple_idx); } block_iterator_t operator-(const difference_type &n) const { - idx_t new_block_idx = block_or_chunk_idx; + idx_t new_block_or_chunk_idx = block_or_chunk_idx; idx_t new_tuple_idx = tuple_idx; - state->Subtract(new_block_idx, new_tuple_idx, n); - return block_iterator_t(*state, new_block_idx, new_tuple_idx); + state->Subtract(new_block_or_chunk_idx, new_tuple_idx, n); + return block_iterator_t(*state, new_block_or_chunk_idx, new_tuple_idx); } reference operator[](const difference_type &n) const { diff --git a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp index 42e68e9ef..882049ddf 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp @@ -161,8 +161,8 @@ class PartitionedTupleData { protected: //! PartitionedTupleData can only be instantiated by derived classes PartitionedTupleData(PartitionedTupleDataType type, BufferManager &buffer_manager, - shared_ptr &layout_ptr); - PartitionedTupleData(const PartitionedTupleData &other); + shared_ptr &layout_ptr, MemoryTag tag); + PartitionedTupleData(PartitionedTupleData &other); //! Whether to use fixed size map or regular map bool UseFixedSizeMap() const; @@ -178,17 +178,23 @@ class PartitionedTupleData { template void BuildBufferSpace(PartitionedTupleDataAppendState &state); //! Create a collection for a specific a partition - unique_ptr CreatePartitionCollection(idx_t partition_index) { - return make_uniq(buffer_manager, layout_ptr); + unique_ptr CreatePartitionCollection() { + return make_uniq(buffer_manager, layout_ptr, tag, stl_allocator); } //! 
Verify count/data size of this PartitionedTupleData void Verify() const; protected: PartitionedTupleDataType type; + BufferManager &buffer_manager; + shared_ptr stl_allocator; + shared_ptr layout_ptr; const TupleDataLayout &layout; + + const MemoryTag tag; + idx_t count; idx_t data_size; diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp index c603baac0..d2578fbd7 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" namespace duckdb { @@ -20,7 +21,7 @@ class ContinuousIdSet; struct TupleDataBlock { public: - TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p); + TupleDataBlock(BufferManager &buffer_manager, MemoryTag tag, idx_t capacity_p); //! Disable copy constructors TupleDataBlock(const TupleDataBlock &other) = delete; @@ -53,7 +54,8 @@ struct TupleDataBlock { class TupleDataAllocator { public: - TupleDataAllocator(BufferManager &buffer_manager, shared_ptr &layout_ptr); + TupleDataAllocator(BufferManager &buffer_manager, shared_ptr layout_ptr, MemoryTag tag, + shared_ptr stl_allocator); TupleDataAllocator(TupleDataAllocator &allocator); ~TupleDataAllocator(); @@ -62,6 +64,8 @@ class TupleDataAllocator { BufferManager &GetBufferManager(); //! Get the buffer allocator Allocator &GetAllocator(); + //! Get the STL allocator + ArenaAllocator &GetStlAllocator(); //! Get the layout shared_ptr GetLayoutPtr() const; const TupleDataLayout &GetLayout() const; @@ -80,7 +84,8 @@ class TupleDataAllocator { const idx_t append_offset, const idx_t append_count); //! Initializes a chunk, making its pointers valid void InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - idx_t chunk_idx, bool init_heap); + idx_t chunk_idx, bool init_heap, + optional_ptr sort_key_payload_state = nullptr); static void RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, const data_ptr_t row_locations[], Vector &new_heap_ptrs, const idx_t offset, const idx_t count, const TupleDataLayout &layout, const idx_t base_col_offset); @@ -99,17 +104,23 @@ class TupleDataAllocator { private: //! Builds out a single part (grabs the lock) - TupleDataChunkPart BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, TupleDataChunk &chunk); + unsafe_arena_ptr BuildChunkPart(TupleDataSegment &segment, TupleDataPinState &pin_state, + TupleDataChunkState &chunk_state, const idx_t append_offset, + const idx_t append_count, TupleDataChunk &chunk); //! Internal function for InitializeChunkState void InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, idx_t offset, bool recompute, bool init_heap_pointers, bool init_heap_sizes, - unsafe_vector> &parts); + unsafe_vector> &parts, + optional_ptr sort_key_payload_state = nullptr); //! 
Internal function for ReleaseOrStoreHandles static void ReleaseOrStoreHandlesInternal(TupleDataSegment &segment, - unsafe_vector &pinned_row_handles, + unsafe_arena_vector &pinned_row_handles, buffer_handle_map_t &handles, const ContinuousIdSet &block_ids, - unsafe_vector &blocks, TupleDataPinProperties properties); + unsafe_arena_vector &blocks, + TupleDataPinProperties properties); + //! Create a row/heap block, extend the pinned handles in the segment accordingly + void CreateRowBlock(TupleDataSegment &segment); + void CreateHeapBlock(TupleDataSegment &segment, idx_t size); //! Pins the given row block BufferHandle &PinRowBlock(TupleDataPinState &state, const TupleDataChunkPart &part); //! Pins the given heap block @@ -120,21 +131,21 @@ class TupleDataAllocator { data_ptr_t GetBaseHeapPointer(TupleDataPinState &state, const TupleDataChunkPart &part); private: + //! Shared allocator for STL allocations + shared_ptr stl_allocator; //! The buffer manager BufferManager &buffer_manager; //! The layout of the data shared_ptr layout_ptr; const TupleDataLayout &layout; + //! Memory tag (for keeping track what the allocated memory belongs to) + const MemoryTag tag; //! Partition index (optional, if partitioned) optional_idx partition_index; //! Blocks storing the fixed-size rows - unsafe_vector row_blocks; + unsafe_arena_vector row_blocks; //! Blocks storing the variable-size data of the fixed-size rows (e.g., string, list) - unsafe_vector heap_blocks; - - //! Re-usable arrays used while building buffer space - unsafe_vector> chunk_parts; - unsafe_vector> chunk_part_indices; + unsafe_arena_vector heap_blocks; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp index d759341ee..3e69d4421 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp @@ -49,7 +49,10 @@ class TupleDataCollection { public: //! Constructs a TupleDataCollection with the specified layout - TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr); + TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr, MemoryTag tag, + shared_ptr stl_allocator = nullptr); + TupleDataCollection(ClientContext &context, shared_ptr layout_ptr, MemoryTag tag, + shared_ptr stl_allocator = nullptr); ~TupleDataCollection(); @@ -172,7 +175,7 @@ class TupleDataCollection { //! Initializes a chunk with the correct types that can be used to call Append/Scan for the given columns void InitializeChunk(DataChunk &chunk, const vector &columns) const; //! Initializes a chunk with the correct types for a given scan state - void InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const; + void InitializeScanChunk(const TupleDataScanState &state, DataChunk &chunk) const; //! Initializes a Scan state for scanning all columns void InitializeScan(TupleDataScanState &state, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; @@ -185,14 +188,17 @@ class TupleDataCollection { //! Initialize a parallel scan over the tuple data collection over a subset of the columns void InitializeScan(TupleDataParallelScanState &gstate, vector column_ids, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! 
Grab the chunk state for the given segment and chunk index, returns the count of the chunk - idx_t FetchChunk(TupleDataScanState &state, idx_t segment_idx, idx_t chunk_idx, bool init_heap); + //! Grab the chunk state for the given chunk index, returns the count of the chunk + idx_t FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap, + optional_ptr sort_key_payload_state = nullptr); //! Scans a DataChunk from the TupleDataCollection bool Scan(TupleDataScanState &state, DataChunk &result); //! Scans a DataChunk from the TupleDataCollection bool Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result); //! Whether the last scan has been completed on this TupleDataCollection bool ScanComplete(const TupleDataScanState &state) const; + //! Seeks to the specified chunk index, returning the total row count before it + idx_t Seek(TupleDataScanState &state, const idx_t target_chunk); //! Gathers a DataChunk from the TupleDataCollection, given the specific row locations (requires full pin) void Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, DataChunk &result, @@ -221,7 +227,7 @@ class TupleDataCollection { //! Gets all column ids void GetAllColumnIDs(vector &column_ids); //! Adds a segment to this TupleDataCollection - void AddSegment(unsafe_unique_ptr segment); + void AddSegment(unsafe_arena_ptr segment); //! Computes the heap sizes for the specific Vector that will be appended static void ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, TupleDataVectorFormat &source, @@ -262,9 +268,13 @@ class TupleDataCollection { void Verify() const; private: + //! Shared allocator for STL allocations + shared_ptr stl_allocator; //! The layout of the TupleDataCollection shared_ptr layout_ptr; const TupleDataLayout &layout; + //! Memory tag (for keeping track what the allocated memory belongs to) + const MemoryTag tag; //! The TupleDataAllocator shared_ptr allocator; //! The number of entries stored in the TupleDataCollection @@ -272,11 +282,11 @@ class TupleDataCollection { //! The size (in bytes) of this TupleDataCollection idx_t data_size; //! The data segments of the TupleDataCollection - unsafe_vector> segments; + unsafe_arena_vector> segments; //! The set of scatter functions - vector scatter_functions; + unsafe_arena_vector scatter_functions; //! The set of gather functions - vector gather_functions; + unsafe_arena_vector gather_functions; //! Partition index (optional, if partitioned) optional_idx partition_index; }; diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp index 22afdb156..93558050a 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp @@ -14,6 +14,7 @@ #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/arena_containers/arena_vector.hpp" namespace duckdb { @@ -49,7 +50,7 @@ struct TupleDataChunkPart { idx_t total_heap_size; //! Tuple count for this chunk part uint32_t count; - //! Lock for recomputing heap pointers (owned by TupleDataChunk) + //! Lock for recomputing heap pointers reference lock; private: @@ -113,7 +114,7 @@ class ContinuousIdSet { struct TupleDataChunk { public: - TupleDataChunk(); + explicit TupleDataChunk(mutex &lock_p); //! 
Disable copy constructors TupleDataChunk(const TupleDataChunk &other) = delete; @@ -124,7 +125,7 @@ struct TupleDataChunk { TupleDataChunk &operator=(TupleDataChunk &&) noexcept; //! Add a part to this chunk - TupleDataChunkPart &AddPart(TupleDataSegment &segment, TupleDataChunkPart &&part); + TupleDataChunkPart &AddPart(TupleDataSegment &segment, unsafe_arena_ptr part_ptr); //! Tries to merge the last chunk part into the second-to-last one void MergeLastChunkPart(TupleDataSegment &segment); //! Verify counts of the parts in this chunk @@ -141,7 +142,7 @@ struct TupleDataChunk { //! Tuple count for this chunk idx_t count; //! Lock for recomputing heap pointers - unsafe_unique_ptr lock; + reference lock; }; struct TupleDataSegment { @@ -171,9 +172,9 @@ struct TupleDataSegment { shared_ptr allocator; const TupleDataLayout &layout; //! The chunks of this segment - unsafe_vector chunks; + unsafe_vector> chunks; //! The chunk parts of this segment - unsafe_vector chunk_parts; + unsafe_vector> chunk_parts; //! The tuple count of this segment idx_t count; //! The data size of this segment @@ -182,9 +183,9 @@ struct TupleDataSegment { //! Lock for modifying pinned_handles mutex pinned_handles_lock; //! Where handles to row blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_row_handles; + unsafe_arena_vector pinned_row_handles; //! Where handles to heap blocks will be stored with TupleDataPinProperties::KEEP_EVERYTHING_PINNED - unsafe_vector pinned_heap_handles; + unsafe_arena_vector pinned_heap_handles; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp index bf22cac33..bbe487d7c 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/types.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/types/vector_cache.hpp" +#include "duckdb/common/sorting/sort_key.hpp" namespace duckdb { @@ -119,13 +120,21 @@ struct TupleDataChunkState { Vector heap_locations = Vector(LogicalType::POINTER); Vector heap_sizes = Vector(LogicalType::UBIGINT); + optional_ptr chunk_lock; + SelectionVector utility = SelectionVector(STANDARD_VECTOR_SIZE); vector> cached_cast_vectors; vector> cached_cast_vector_cache; - //! Cached vector (for InitializeChunkState) - unsafe_vector> parts; + //! 
Re-usable arrays used while building buffer space + unsafe_vector> chunk_parts; + unsafe_vector> chunk_part_indices; +}; + +struct SortKeyPayloadState { + TupleDataChunkState &sort_key_chunk_state; + SortKeyType sort_key_type; }; struct TupleDataAppendState { diff --git a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp index ceb5637ac..5575e5a08 100644 --- a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp @@ -108,6 +108,7 @@ struct SelectionVector { return selection_data; } buffer_ptr Slice(const SelectionVector &sel, idx_t count) const; + idx_t SliceInPlace(const SelectionVector &sel, idx_t count); string ToString(idx_t count = 0) const; void Print(idx_t count = 0) const; diff --git a/src/duckdb/src/include/duckdb/common/types/string_type.hpp b/src/duckdb/src/include/duckdb/common/types/string_type.hpp index 59bb3c293..2d7bbcf56 100644 --- a/src/duckdb/src/include/duckdb/common/types/string_type.hpp +++ b/src/duckdb/src/include/duckdb/common/types/string_type.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/assert.hpp" +#include "duckdb/common/bswap.hpp" #include "duckdb/common/constants.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/common/numeric_utils.hpp" @@ -194,22 +195,14 @@ struct string_t { uint32_t a_prefix = Load(const_data_ptr_cast(left.GetPrefix())); uint32_t b_prefix = Load(const_data_ptr_cast(right.GetPrefix())); - // Utility to move 0xa1b2c3d4 into 0xd4c3b2a1, basically inverting the order byte-a-byte - auto byte_swap = [](uint32_t v) -> uint32_t { - uint32_t t1 = (v >> 16u) | (v << 16u); - uint32_t t2 = t1 & 0x00ff00ff; - uint32_t t3 = t1 & 0xff00ff00; - return (t2 << 8u) | (t3 >> 8u); - }; - // Check on prefix ----- - // We dont' need to mask since: + // We don't need to mask since: // if the prefix is greater(after bswap), it will stay greater regardless of the extra bytes // if the prefix is smaller(after bswap), it will stay smaller regardless of the extra bytes // if the prefix is equal, the extra bytes are guaranteed to be /0 for the shorter one if (a_prefix != b_prefix) { - return byte_swap(a_prefix) > byte_swap(b_prefix); + return BSwapIfLE(a_prefix) > BSwapIfLE(b_prefix); } #endif auto memcmp_res = memcmp(left.GetData(), right.GetData(), min_length); diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index 1993d0295..bba9a7297 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -201,6 +201,8 @@ class Value { DUCKDB_API static Value BIGNUM(const_data_ptr_t data, idx_t len); DUCKDB_API static Value BIGNUM(const string &data); + DUCKDB_API static Value GEOMETRY(const_data_ptr_t data, idx_t len); + //! 
Creates an aggregate state DUCKDB_API static Value AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index cc8a9ffa6..280c9e695 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -1,8 +1,9 @@ #pragma once #include "duckdb/common/typedefs.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/string_type.hpp" namespace duckdb_yyjson { struct yyjson_mut_doc; @@ -10,6 +11,11 @@ struct yyjson_mut_val; } // namespace duckdb_yyjson namespace duckdb { +class Vector; +struct ValidityMask; +struct UnifiedVariantVector; +struct RecursiveUnifiedVectorFormat; +struct UnifiedVectorFormat; enum class VariantChildLookupMode : uint8_t { BY_KEY, BY_INDEX }; @@ -29,24 +35,23 @@ struct VariantNestedData { }; struct VariantDecimalData { +public: + VariantDecimalData(uint32_t width, uint32_t scale, const_data_ptr_t value_ptr) + : width(width), scale(scale), value_ptr(value_ptr) { + } + +public: + PhysicalType GetPhysicalType() const; + +public: uint32_t width; uint32_t scale; + const_data_ptr_t value_ptr = nullptr; }; struct VariantVectorData { public: - explicit VariantVectorData(Vector &variant) - : variant(variant), keys_index_validity(FlatVector::Validity(VariantVector::GetChildrenKeysIndex(variant))), - keys(VariantVector::GetKeys(variant)) { - blob_data = FlatVector::GetData(VariantVector::GetData(variant)); - type_ids_data = FlatVector::GetData(VariantVector::GetValuesTypeId(variant)); - byte_offset_data = FlatVector::GetData(VariantVector::GetValuesByteOffset(variant)); - keys_index_data = FlatVector::GetData(VariantVector::GetChildrenKeysIndex(variant)); - values_index_data = FlatVector::GetData(VariantVector::GetChildrenValuesIndex(variant)); - values_data = FlatVector::GetData(VariantVector::GetValues(variant)); - children_data = FlatVector::GetData(VariantVector::GetChildren(variant)); - keys_data = FlatVector::GetData(keys); - } + explicit VariantVectorData(Vector &variant); public: Vector &variant; @@ -105,68 +110,25 @@ enum class VariantLogicalType : uint8_t { ARRAY = 30, BIGNUM = 31, BITSTRING = 32, + GEOMETRY = 33, ENUM_SIZE /* always kept as last item of the enum */ }; struct UnifiedVariantVectorData { public: - explicit UnifiedVariantVectorData(const RecursiveUnifiedVectorFormat &variant) - : variant(variant), keys(UnifiedVariantVector::GetKeys(variant)), - keys_entry(UnifiedVariantVector::GetKeysEntry(variant)), children(UnifiedVariantVector::GetChildren(variant)), - keys_index(UnifiedVariantVector::GetChildrenKeysIndex(variant)), - values_index(UnifiedVariantVector::GetChildrenValuesIndex(variant)), - values(UnifiedVariantVector::GetValues(variant)), type_id(UnifiedVariantVector::GetValuesTypeId(variant)), - byte_offset(UnifiedVariantVector::GetValuesByteOffset(variant)), data(UnifiedVariantVector::GetData(variant)), - keys_index_validity(keys_index.validity) { - blob_data = data.GetData(); - type_id_data = type_id.GetData(); - byte_offset_data = byte_offset.GetData(); - keys_index_data = keys_index.GetData(); - values_index_data = values_index.GetData(); - values_data = values.GetData(); - children_data = children.GetData(); - keys_data = keys.GetData(); - keys_entry_data = keys_entry.GetData(); - } + 
explicit UnifiedVariantVectorData(const RecursiveUnifiedVectorFormat &variant); public: - bool RowIsValid(idx_t row) const { - return variant.unified.validity.RowIsValid(variant.unified.sel->get_index(row)); - } - bool KeysIndexIsValid(idx_t row, idx_t index) const { - auto list_entry = GetChildrenListEntry(row); - return keys_index_validity.RowIsValid(keys_index.sel->get_index(list_entry.offset + index)); - } - - list_entry_t GetChildrenListEntry(idx_t row) const { - return children_data[children.sel->get_index(row)]; - } - list_entry_t GetValuesListEntry(idx_t row) const { - return values_data[values.sel->get_index(row)]; - } - const string_t &GetKey(idx_t row, idx_t index) const { - auto list_entry = keys_data[keys.sel->get_index(row)]; - return keys_entry_data[keys_entry.sel->get_index(list_entry.offset + index)]; - } - uint32_t GetKeysIndex(idx_t row, idx_t child_index) const { - auto list_entry = GetChildrenListEntry(row); - return keys_index_data[keys_index.sel->get_index(list_entry.offset + child_index)]; - } - uint32_t GetValuesIndex(idx_t row, idx_t child_index) const { - auto list_entry = GetChildrenListEntry(row); - return values_index_data[values_index.sel->get_index(list_entry.offset + child_index)]; - } - VariantLogicalType GetTypeId(idx_t row, idx_t value_index) const { - auto list_entry = values_data[values.sel->get_index(row)]; - return static_cast(type_id_data[type_id.sel->get_index(list_entry.offset + value_index)]); - } - uint32_t GetByteOffset(idx_t row, idx_t value_index) const { - auto list_entry = values_data[values.sel->get_index(row)]; - return byte_offset_data[byte_offset.sel->get_index(list_entry.offset + value_index)]; - } - const string_t &GetData(idx_t row) const { - return blob_data[data.sel->get_index(row)]; - } + bool RowIsValid(idx_t row) const; + bool KeysIndexIsValid(idx_t row, idx_t index) const; + list_entry_t GetChildrenListEntry(idx_t row) const; + list_entry_t GetValuesListEntry(idx_t row) const; + const string_t &GetKey(idx_t row, idx_t index) const; + uint32_t GetKeysIndex(idx_t row, idx_t child_index) const; + uint32_t GetValuesIndex(idx_t row, idx_t child_index) const; + VariantLogicalType GetTypeId(idx_t row, idx_t value_index) const; + uint32_t GetByteOffset(idx_t row, idx_t value_index) const; + const string_t &GetData(idx_t row) const; public: const RecursiveUnifiedVectorFormat &variant; diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp b/src/duckdb/src/include/duckdb/common/types/variant_value.hpp similarity index 83% rename from src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp rename to src/duckdb/src/include/duckdb/common/types/variant_value.hpp index a4c38ede7..735f06e4b 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant_value.hpp @@ -4,9 +4,10 @@ #include "duckdb/common/vector.hpp" #include "duckdb/common/types/value.hpp" -#include "yyjson.hpp" - -using namespace duckdb_yyjson; +namespace duckdb_yyjson { +struct yyjson_mut_doc; +struct yyjson_mut_val; +} // namespace duckdb_yyjson namespace duckdb { @@ -41,7 +42,8 @@ struct VariantValue { void AddItem(VariantValue &&val); public: - yyjson_mut_val *ToJSON(ClientContext &context, yyjson_mut_doc *doc) const; + duckdb_yyjson::yyjson_mut_val *ToJSON(ClientContext &context, duckdb_yyjson::yyjson_mut_doc *doc) const; + static void ToVARIANT(vector &input, Vector &result); public: VariantValueType value_type; diff --git 
a/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp b/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp new file mode 100644 index 000000000..950980aef --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/variant_visitor.hpp @@ -0,0 +1,232 @@ +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/enum_util.hpp" + +#include + +namespace duckdb { + +template +class VariantVisitor { + // Detects if T has a static VisitMetadata with signature + // void VisitMetadata(VariantLogicalType, Args...) + template + class has_visit_metadata { + private: + template + static auto test(int) -> decltype(U::VisitMetadata(std::declval(), std::declval()...), + std::true_type {}); + + template + static std::false_type test(...); + + public: + static constexpr bool value = decltype(test(0))::value; + }; + +public: + template + static ReturnType Visit(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, Args &&...args) { + if (!variant.RowIsValid(row)) { + return Visitor::VisitNull(std::forward(args)...); + } + + auto type_id = variant.GetTypeId(row, values_idx); + auto byte_offset = variant.GetByteOffset(row, values_idx); + auto blob_data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = const_data_ptr_cast(blob_data + byte_offset); + + VisitMetadata(type_id, std::forward(args)...); + + switch (type_id) { + case VariantLogicalType::VARIANT_NULL: + return Visitor::VisitNull(std::forward(args)...); + case VariantLogicalType::BOOL_TRUE: + return Visitor::VisitBoolean(true, std::forward(args)...); + case VariantLogicalType::BOOL_FALSE: + return Visitor::VisitBoolean(false, std::forward(args)...); + case VariantLogicalType::INT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::INT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT8: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT16: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT32: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT64: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::UINT128: + return Visitor::template VisitInteger(Load(ptr), std::forward(args)...); + case VariantLogicalType::FLOAT: + return Visitor::VisitFloat(Load(ptr), std::forward(args)...); + case VariantLogicalType::DOUBLE: + return Visitor::VisitDouble(Load(ptr), std::forward(args)...); + case VariantLogicalType::UUID: + return Visitor::VisitUUID(Load(ptr), std::forward(args)...); + case VariantLogicalType::DATE: + return Visitor::VisitDate(date_t(Load(ptr)), std::forward(args)...); + case VariantLogicalType::INTERVAL: + return Visitor::VisitInterval(Load(ptr), std::forward(args)...); + case VariantLogicalType::VARCHAR: + case VariantLogicalType::BLOB: + case 
VariantLogicalType::BITSTRING: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::GEOMETRY: + return VisitString(type_id, variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::DECIMAL: + return VisitDecimal(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::ARRAY: + return VisitArray(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::OBJECT: + return VisitObject(variant, row, values_idx, std::forward(args)...); + case VariantLogicalType::TIME_MICROS: + return Visitor::VisitTime(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_NANOS: + return Visitor::VisitTimeNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIME_MICROS_TZ: + return Visitor::VisitTimeTZ(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_SEC: + return Visitor::VisitTimestampSec(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MILIS: + return Visitor::VisitTimestampMs(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS: + return Visitor::VisitTimestamp(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_NANOS: + return Visitor::VisitTimestampNanos(Load(ptr), std::forward(args)...); + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + return Visitor::VisitTimestampTZ(Load(ptr), std::forward(args)...); + default: + return Visitor::VisitDefault(type_id, ptr, std::forward(args)...); + } + } + + // Non-void version + template + static typename std::enable_if::value, vector>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + vector array_items; + array_items.reserve(array_data.child_count); + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + array_items.emplace_back(Visit(variant, row, values_index, std::forward(args)...)); + } + return array_items; + } + + // Void version + template + static typename std::enable_if::value, void>::type + VisitArrayItems(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &array_data, + Args &&...args) { + for (idx_t i = 0; i < array_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, array_data.children_idx + i); + Visit(variant, row, values_index, std::forward(args)...); + } + } + + template + static child_list_t VisitObjectItems(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &object_data, Args &&...args) { + child_list_t object_items; + for (idx_t i = 0; i < object_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, object_data.children_idx + i); + auto val = Visit(variant, row, values_index, std::forward(args)...); + + auto keys_index = variant.GetKeysIndex(row, object_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + + object_items.emplace_back(key.GetString(), std::move(val)); + } + return object_items; + } + +private: + template + static typename std::enable_if::value, void>::type + VisitMetadata(VariantLogicalType type_id, Args &&...args) { + Visitor::VisitMetadata(type_id, std::forward(args)...); + } + + // Fallback if the method does not exist + template + static typename std::enable_if::value, void>::type VisitMetadata(VariantLogicalType, + Args &&...) 
{ + // do nothing + } + + template + static ReturnType VisitArray(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitArray(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitObject(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_nested_data = VariantUtils::DecodeNestedData(variant, row, values_idx); + return Visitor::VisitObject(variant, row, decoded_nested_data, std::forward(args)...); + } + + template + static ReturnType VisitString(VariantLogicalType type_id, const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx, Args &&...args) { + auto decoded_string = VariantUtils::DecodeStringData(variant, row, values_idx); + if (type_id == VariantLogicalType::VARCHAR) { + return Visitor::VisitString(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BLOB) { + return Visitor::VisitBlob(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BIGNUM) { + return Visitor::VisitBignum(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::GEOMETRY) { + return Visitor::VisitGeometry(decoded_string, std::forward(args)...); + } + if (type_id == VariantLogicalType::BITSTRING) { + return Visitor::VisitBitstring(decoded_string, std::forward(args)...); + } + throw InternalException("String-backed variant type (%s) not handled", EnumUtil::ToString(type_id)); + } + + template + static ReturnType VisitDecimal(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx, + Args &&...args) { + auto decoded_decimal = VariantUtils::DecodeDecimalData(variant, row, values_idx); + auto &width = decoded_decimal.width; + auto &scale = decoded_decimal.scale; + auto &ptr = decoded_decimal.value_ptr; + if (width > DecimalWidth::max) { + throw InternalException("Can't handle decimal of width: %d", width); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else if (width > DecimalWidth::max) { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } else { + return Visitor::template VisitDecimal(Load(ptr), width, scale, + std::forward(args)...); + } + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/vector.hpp b/src/duckdb/src/include/duckdb/common/types/vector.hpp index 1ab48c056..890118013 100644 --- a/src/duckdb/src/include/duckdb/common/types/vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/vector.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/bitset.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/mutex.hpp" #include "duckdb/common/types/selection_vector.hpp" #include "duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/value.hpp" @@ -21,6 +22,7 @@ namespace duckdb { class VectorCache; +class VectorChildBuffer; class VectorStringBuffer; class VectorStructBuffer; class VectorListBuffer; @@ -195,6 +197,8 @@ class Vector { DUCKDB_API void Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t count); //! 
Creates a reference to a dictionary of the other vector DUCKDB_API void Dictionary(Vector &dict, idx_t dictionary_size, const SelectionVector &sel, idx_t count); + //! Creates a dictionary on the reusable dict + DUCKDB_API void Dictionary(buffer_ptr reusable_dict, const SelectionVector &sel); //! Creates the data of this vector with the specified type. Any data that //! is currently in the vector is destroyed. @@ -306,20 +310,24 @@ class Vector { //! The buffer holding auxiliary data of the vector //! e.g. a string vector uses this to store strings buffer_ptr auxiliary; - //! The buffer holding precomputed hashes of the data in the vector - //! used for caching hashes of string dictionaries - buffer_ptr cached_hashes; }; -//! The DictionaryBuffer holds a selection vector +//! The VectorChildBuffer holds a child Vector class VectorChildBuffer : public VectorBuffer { public: explicit VectorChildBuffer(Vector vector) - : VectorBuffer(VectorBufferType::VECTOR_CHILD_BUFFER), data(std::move(vector)) { + : VectorBuffer(VectorBufferType::VECTOR_CHILD_BUFFER), data(std::move(vector)), + cached_hashes(LogicalType::HASH, nullptr) { } public: Vector data; + //! Optional size/id to uniquely identify re-occurring dictionaries + optional_idx size; + string id; + //! For caching the hashes of a child buffer + mutex cached_hashes_lock; + Vector cached_hashes; }; struct ConstantVector { @@ -409,22 +417,27 @@ struct DictionaryVector { } static inline optional_idx DictionarySize(const Vector &vector) { VerifyDictionary(vector); + const auto &child_buffer = vector.auxiliary->Cast(); + if (child_buffer.size.IsValid()) { + return child_buffer.size; + } return vector.buffer->Cast().GetDictionarySize(); } static inline const string &DictionaryId(const Vector &vector) { VerifyDictionary(vector); + const auto &child_buffer = vector.auxiliary->Cast(); + if (!child_buffer.id.empty()) { + return child_buffer.id; + } return vector.buffer->Cast().GetDictionaryId(); } - static inline void SetDictionaryId(Vector &vector, string new_id) { - VerifyDictionary(vector); - vector.buffer->Cast().SetDictionaryId(std::move(new_id)); - } static inline bool CanCacheHashes(const LogicalType &type) { return type.InternalType() == PhysicalType::VARCHAR; } static inline bool CanCacheHashes(const Vector &vector) { return DictionarySize(vector).IsValid() && CanCacheHashes(vector.GetType()); } + static buffer_ptr CreateReusableDictionary(const LogicalType &type, const idx_t &size); static const Vector &GetCachedHashes(Vector &input); }; @@ -488,6 +501,13 @@ struct FlatVector { }; struct ListVector { + static inline const list_entry_t *GetData(const Vector &v) { + if (v.GetVectorType() == VectorType::DICTIONARY_VECTOR) { + auto &child = DictionaryVector::Child(v); + return GetData(child); + } + return FlatVector::GetData(v); + } static inline list_entry_t *GetData(Vector &v) { if (v.GetVectorType() == VectorType::DICTIONARY_VECTOR) { auto &child = DictionaryVector::Child(v); diff --git a/src/duckdb/src/include/duckdb/common/unique_ptr.hpp b/src/duckdb/src/include/duckdb/common/unique_ptr.hpp index f5d0972d7..f4bdc7b63 100644 --- a/src/duckdb/src/include/duckdb/common/unique_ptr.hpp +++ b/src/duckdb/src/include/duckdb/common/unique_ptr.hpp @@ -1,3 +1,4 @@ + #pragma once #include "duckdb/common/exception.hpp" diff --git a/src/duckdb/src/include/duckdb/common/vector.hpp b/src/duckdb/src/include/duckdb/common/vector.hpp index 676adac20..3035edcd5 100644 --- a/src/duckdb/src/include/duckdb/common/vector.hpp +++ 
b/src/duckdb/src/include/duckdb/common/vector.hpp @@ -17,14 +17,23 @@ namespace duckdb { -template -class vector : public std::vector> { // NOLINT: matching name of std +template > +class vector : public std::vector { // NOLINT: matching name of std public: - using original = std::vector>; + using original = std::vector; using original::original; + using value_type = typename original::value_type; + using allocator_type = typename original::allocator_type; using size_type = typename original::size_type; - using const_reference = typename original::const_reference; + using difference_type = typename original::difference_type; using reference = typename original::reference; + using const_reference = typename original::const_reference; + using pointer = typename original::pointer; + using const_pointer = typename original::const_pointer; + using iterator = typename original::iterator; + using const_iterator = typename original::const_iterator; + using reverse_iterator = typename original::reverse_iterator; + using const_reverse_iterator = typename original::const_reverse_iterator; private: static inline void AssertIndexInBounds(idx_t index, idx_t size) { diff --git a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp index 6a0f0346a..b35e7bb7d 100644 --- a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/file_system.hpp" #include "duckdb/common/map.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/extension_helper.hpp" namespace duckdb { @@ -29,6 +30,7 @@ class VirtualFileSystem : public FileSystem { timestamp_t GetLastModifiedTime(FileHandle &handle) override; string GetVersionTag(FileHandle &handle) override; FileType GetFileType(FileHandle &handle) override; + FileMetadata Stats(FileHandle &handle) override; void Truncate(FileHandle &handle, int64_t new_size) override; @@ -64,6 +66,7 @@ class VirtualFileSystem : public FileSystem { void SetDisabledFileSystems(const vector &names) override; bool SubSystemIsDisabled(const string &name) override; + bool IsDisabledForPath(const string &path) override; string PathSeparator(const string &path) override; @@ -82,8 +85,10 @@ class VirtualFileSystem : public FileSystem { } private: + FileSystem &FindFileSystem(const string &path, optional_ptr file_opener); + FileSystem &FindFileSystem(const string &path, optional_ptr database_instance); FileSystem &FindFileSystem(const string &path); - FileSystem &FindFileSystemInternal(const string &path); + optional_ptr FindFileSystemInternal(const string &path); private: vector> sub_systems; diff --git a/src/duckdb/src/include/duckdb/common/weak_ptr_ipp.hpp b/src/duckdb/src/include/duckdb/common/weak_ptr_ipp.hpp index 076fde953..1a3d4990b 100644 --- a/src/duckdb/src/include/duckdb/common/weak_ptr_ipp.hpp +++ b/src/duckdb/src/include/duckdb/common/weak_ptr_ipp.hpp @@ -1,3 +1,7 @@ +#pragma once + +#include "duckdb/common/shared_ptr_ipp.hpp" + namespace duckdb { template @@ -114,4 +118,7 @@ class weak_ptr { // NOLINT: invalid case style } }; +template +using unsafe_weak_ptr = weak_ptr; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp index 8f0b77ccf..112e76109 100644 --- a/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/expression_executor_state.hpp @@ -75,12 +75,10 @@ struct ExecuteFunctionState : public ExpressionState { //! Only valid when the expression is eligible for the dictionary expression optimization //! This is the case when the input is "practically unary", i.e., only one non-const input column optional_idx input_col_idx; - //! Storage ID of the input dictionary vector - string current_input_dictionary_id; //! Vector holding the expression executed on the entire dictionary - unique_ptr output_dictionary; - //! ID of the output dictionary_vector - string output_dictionary_id; + buffer_ptr output_dictionary; + //! ID of the input dictionary Vector + string current_input_dictionary_id; }; struct ExpressionExecutorState { diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp index 71e64cfe7..eef273ba7 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -110,19 +110,26 @@ class ART : public BoundIndex { //! Returns the in-memory usage of the ART. idx_t GetInMemorySize(IndexLock &index_lock) override; + bool RequiresTransactionality() const override; + unique_ptr CreateEmptyCopy(const string &name_prefix, + IndexConstraintType constraint_type) const override; + //! ART key generation. template void GenerateKeys(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys); void GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, unsafe_vector &row_id_keys); - //! Verifies the nodes and optionally returns a string of the ART. - string VerifyAndToString(IndexLock &l, const bool only_verify) override; + //! Verifies the nodes. + void Verify(IndexLock &l) override; //! Verifies that the node allocations match the node counts. void VerifyAllocations(IndexLock &l) override; //! Verifies the index buffers. void VerifyBuffers(IndexLock &l) override; + //! Returns string representation of the ART. + string ToString(IndexLock &l, bool display_ascii = false) override; + private: bool SearchEqual(ARTKey &key, idx_t max_count, set &row_ids); bool SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &row_ids); @@ -151,7 +158,8 @@ class ART : public BoundIndex { void WritePartialBlocks(QueryContext context, const bool v1_0_0_storage); void SetPrefixCount(const IndexStorageInfo &info); - string VerifyAndToStringInternal(const bool only_verify); + string ToStringInternal(bool display_ascii); + void VerifyInternal(); void VerifyAllocationsInternal(); }; diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp index 62903b198..3ed7f46ff 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp @@ -62,6 +62,60 @@ class ARTOperator { return nullptr; } + //! LookupInLeaf returns true if the rowid is in the leaf: + //! 1) If the leaf is an inlined leaf, check if the rowid matches. + //! 2) If the leaf is a gate node, perform a search in the nested ART for the rowid. 
+ static bool LookupInLeaf(ART &art, const Node &node, const ARTKey &rowid) { + reference ref(node); + idx_t depth = 0; + + while (ref.get().HasMetadata()) { + const auto type = ref.get().GetType(); + switch (type) { + case NType::LEAF_INLINED: { + return ref.get().GetRowId() == rowid.GetRowId(); + } + case NType::LEAF: { + throw InternalException("Invalid node type (LEAF) for ARTOperator::NestedLookup."); + } + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: { + D_ASSERT(depth + 1 == Prefix::ROW_ID_SIZE); + const auto byte = rowid[Prefix::ROW_ID_COUNT]; + return ref.get().HasByte(art, byte); + } + case NType::NODE_4: + case NType::NODE_16: + case NType::NODE_48: + case NType::NODE_256: { + D_ASSERT(depth < Prefix::ROW_ID_SIZE); + auto child = ref.get().GetChild(art, rowid[depth]); + if (child) { + // Continue in the child. + ref = *child; + depth++; + D_ASSERT(ref.get().HasMetadata()); + continue; + } + return false; + } + case NType::PREFIX: { + Prefix prefix(art, ref.get()); + for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { + if (prefix.data[i] != rowid[depth]) { + // The key and the prefix don't match. + return false; + } + depth++; + } + ref = *prefix.ptr; + } + } + } + return false; + } + //! Insert a key and its row ID into the node. //! Starts at depth (in the key). //! status indicates if the insert happens inside a gate or not. @@ -202,6 +256,8 @@ class ARTOperator { if (parent.get().GetType() == NType::PREFIX) { // We might have to compress: // PREFIX (greatgrandparent) - Node4 (grandparent) - PREFIX - INLINED_LEAF. + // The parent does not have to be passed in, as it is a child of the possibly being compressed N4. + // Then, when we delete that child, we also free it. Node::DeleteChild(art, grandparent, greatgrandparent, current_key.get()[grandparent_depth], status, row_id); return; @@ -336,7 +392,6 @@ class ARTOperator { static void InsertIntoPrefix(ART &art, reference &node_ref, const ARTKey &key, const ARTKey &row_id, const idx_t pos, const idx_t depth, const GateStatus status) { - const auto cast_pos = UnsafeNumericCast(pos); const auto byte = Prefix::GetByte(art, node_ref, cast_pos); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp index 209d022dc..797c18469 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/base_leaf.hpp @@ -31,13 +31,15 @@ class BaseLeaf { public: //! Get a new BaseLeaf and initialize it. - static BaseLeaf &New(ART &art, Node &node) { + static NodeHandle New(ART &art, Node &node) { node = Node::GetAllocator(art, TYPE).New(); node.SetMetadata(static_cast(TYPE)); - auto &n = Node::Ref(art, node, TYPE); + NodeHandle handle(art, node); + auto &n = handle.Get(); + n.count = 0; - return n; + return handle; } //! Returns true, if the byte exists, else false. @@ -70,7 +72,7 @@ class BaseLeaf { private: static void InsertByteInternal(BaseLeaf &n, const uint8_t byte); - static BaseLeaf &DeleteByteInternal(ART &art, Node &node, const uint8_t byte); + static NodeHandle DeleteByteInternal(ART &art, Node &node, const uint8_t byte); }; //! Node7Leaf holds up to seven sorted bytes. 
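The BaseLeaf::New change in the hunk above is part of a broader shift in these ART headers: instead of handing back a bare reference into the fixed-size allocator, New now returns a NodeHandle and initializes the node through handle.Get(), presumably so the caller's access stays valid for as long as the handle lives. A minimal sketch of that handle pattern, using simplified stand-in types (Segment and SegmentHandle below are illustrative, not the actual ART classes):

#include <cstdint>

// Stand-in for an allocator-backed node segment (illustrative only).
struct Segment {
	uint8_t count = 0;
};

// Handle that keeps a segment accessible while it is read or mutated.
// A real handle would pin the backing buffer on construction and unpin it on destruction.
class SegmentHandle {
public:
	explicit SegmentHandle(Segment &segment_p) : segment(segment_p) {
	}
	Segment &Get() {
		return segment;
	}

private:
	Segment &segment;
};

// Mirrors the shape of the new BaseLeaf::New: allocate, wrap in a handle,
// initialize through the handle, and return the handle to the caller.
SegmentHandle NewSegment(Segment &storage) {
	SegmentHandle handle(storage);
	auto &n = handle.Get();
	n.count = 0;
	return handle;
}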
diff --git a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp index c5907f820..793dcf40b 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp @@ -70,7 +70,7 @@ class Iterator { void FindMinimum(const Node &node); //! Finds the lower bound of the ART and adds the nodes to the stack. Returns false, if the lower //! bound exceeds the maximum value of the ART. - bool LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth); + bool LowerBound(const Node &node, const ARTKey &key, const bool equal); //! Returns the nested depth. uint8_t GetNestedDepth() const { diff --git a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp index 30efdba0a..481e10521 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp @@ -57,11 +57,15 @@ class Leaf { static bool DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, const idx_t max_count); //! Vacuums the linked list of leaves. static void DeprecatedVacuum(ART &art, Node &node); - //! Returns the string representation of the linked list of leaves, if only_verify is true. - //! Else, it traverses and verifies the linked list of leaves. - static string DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify); + + //! Traverses and verifies the linked list of leaves. + static void DeprecatedVerify(ART &art, const Node &node); //! Count the number of leaves. void DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const; + + //! Return string representation of the linked list of leaves. + //! If print_deprecated_leaves is false, returns "[deprecated leaves]" with proper indentation. + static string DeprecatedToString(ART &art, const Node &node, const ToStringOptions &options); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp index ae00d3e6d..1fa9ade88 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp @@ -39,6 +39,43 @@ class ART; class Prefix; class ARTKey; +//! Options for ToString printing functions +struct ToStringOptions { + // Indentation for root node. + idx_t indent_level = 0; + // Amount to increase idnentation when traversing to a child node. + idx_t indent_amount = 4; + bool inside_gate = false; + bool display_ascii = false; + // Optional key argument to only print the path along to a specific key. + // This prints nodes along the path, as well as the child bytes, but doesn't traverse into children not on the path + // to the optional key_path. + // This works in conjunction with the depth_remaining and structure_only arguments. + // Note that nested ARTs are printed in their entirety regardless. + optional_ptr key_path = nullptr; + idx_t key_depth = 0; + // If we have a key_path argument, we only print along a certain path to a specified key. depth_remaining allows us + // to short circuit that, and print the entire tree starting at a certain depth. So if we are traversing towards + // the leaf for a key, we can start printing the entire tree again. This is useful to be able to see a region of the + // ART around a specific leaf. 
+ idx_t depth_remaining = 0; + bool print_deprecated_leaves = true; + // Similar to key path, but don't print the other child bytes at each node along the path to the key, i.e. skip + // printing node contents. This gives a very barebones skeleton of the node structure leading to a key, and this + // can also be short circuited by depth_remaining. + bool structure_only = false; + + ToStringOptions() = default; + + ToStringOptions(idx_t indent_level, bool inside_gate, bool display_ascii, optional_ptr key_path, + idx_t key_depth, idx_t depth_remaining, bool print_deprecated_leaves, bool structure_only, + idx_t indent_amount = 2) + : indent_level(indent_level), indent_amount(indent_amount), inside_gate(inside_gate), + display_ascii(display_ascii), key_path(key_path), key_depth(key_depth), depth_remaining(depth_remaining), + print_deprecated_leaves(print_deprecated_leaves), structure_only(structure_only) { + } +}; + //! The Node is the pointer class of the ART index. //! It inherits from the IndexPointer, and adds ART-specific functionality. class Node : public IndexPointer { @@ -94,9 +131,8 @@ class Node : public IndexPointer { //! Get the first byte greater than or equal to the byte. bool GetNextByte(ART &art, uint8_t &byte) const; - //! Returns the string representation of the node, if only_verify is false. - //! Else, it traverses and verifies the node. - string VerifyAndToString(ART &art, const bool only_verify) const; + //! Traverses and verifies the node. + void Verify(ART &art) const; //! Counts each node type. void VerifyAllocations(ART &art, unordered_map &node_counts) const; @@ -107,6 +143,13 @@ class Node : public IndexPointer { static void TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &deprecated_prefix_allocator); + //! Returns the string representation of the node at indentation level. + //! + //! Parameters: + //! - art: root node of tree being printed. + //! - options: Printing options (see ToStringOptions struct for details). + string ToString(ART &art, const ToStringOptions &options) const; + //! Returns the node type. inline NType GetType() const { return NType(GetMetadata() & ~AND_GATE); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp index 835e32c0f..902ba0ad9 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp @@ -48,7 +48,7 @@ class Prefix { //! Concatenates parent -> prev_node4 -> child. static void Concat(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte, - const GateStatus node4_status); + const GateStatus node4_status, const GateStatus status); //! Removes up to pos bytes from the prefix. //! Shifts all subsequent bytes by pos. Frees empty nodes. @@ -61,18 +61,21 @@ class Prefix { //! after its creation. static GateStatus Split(ART &art, reference &node, Node &child, const uint8_t pos); - //! Returns the string representation of the node, or only traverses and verifies the node and its subtree - static string VerifyAndToString(ART &art, const Node &node, const bool only_verify); + //! Traverses and verifies the node and its subtree + static void Verify(ART &art, const Node &node); //! Transform the child of the node. static void TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator); + //! Returns the string representation of the node using ToStringOptions. 
+ static string ToString(ART &art, const Node &node, const ToStringOptions &options); + private: static Prefix NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset); static Prefix GetTail(ART &art, const Node &node); static void ConcatInternal(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte, - const bool inside_gate); + const GateStatus status); static void ConcatNode4WasGate(ART &art, Node &node4, const Node child, uint8_t byte); static void ConcatChildIsGate(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte); static void ConcatOutsideGate(ART &art, Node &parent, Node &node4, const Node child, uint8_t byte); diff --git a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp index 914288bfa..ae6daa0cd 100644 --- a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/common/enums/index_constraint_type.hpp" #include "duckdb/common/types/constraint_conflict_info.hpp" #include "duckdb/common/types/data_chunk.hpp" @@ -60,6 +61,16 @@ class BoundIndex : public Index { //! The index constraint type IndexConstraintType index_constraint_type; + //! The vector of unbound expressions, which are later turned into bound expressions. + //! We need to store the unbound expressions, as we might not always have the context + //! available to bind directly. + //! The leaves of these unbound expressions are BoundColumnRefExpressions. + //! These BoundColumnRefExpressions contain a binding (ColumnBinding), + //! and that contains a table_index and a column_index. + //! The table_index is a dummy placeholder. + //! The column_index indexes the column_ids vector in the Index base class. + //! Those column_ids store the physical table indexes of the Index, + //! and we use them when binding the unbound expressions. vector> unbound_expressions; public: @@ -119,15 +130,27 @@ class BoundIndex : public Index { //! Obtains a lock and calls Vacuum while holding that lock. void Vacuum(); + //! Whether or not the index requires transactionality. If true we will create delta indexes + virtual bool RequiresTransactionality() const; + //! Creates an empty copy of the index with the same schema, etc, but a different constraint type + //! This will only be called if RequiresTransactionality returns true + virtual unique_ptr CreateEmptyCopy(const string &name_prefix, + IndexConstraintType constraint_type) const; + //! Returns the in-memory usage of the index. The lock obtained from InitializeLock must be held virtual idx_t GetInMemorySize(IndexLock &state) = 0; //! Returns the in-memory usage of the index idx_t GetInMemorySize(); //! Returns the string representation of an index, or only traverses and verifies the index. - virtual string VerifyAndToString(IndexLock &l, const bool only_verify) = 0; + virtual void Verify(IndexLock &l) = 0; //! Obtains a lock and calls VerifyAndToString. - string VerifyAndToString(const bool only_verify); + void Verify(); + + //! Returns the string representation of an index. + virtual string ToString(IndexLock &l, bool display_ascii = false) = 0; + //! Obtains a lock and calls ToString. + string ToString(bool display_ascii = false); //! Ensures that the node allocation counts match the node counts. 
virtual void VerifyAllocations(IndexLock &l) = 0; @@ -155,14 +178,22 @@ class BoundIndex : public Index { virtual string GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index, DataChunk &input) = 0; - void ApplyBufferedAppends(const vector &table_types, ColumnDataCollection &buffered_appends, + //! Replay index insert and delete operations buffered during WAL replay. + //! table_types has the physical types of the table in the order they appear, not logical (no generated columns). + //! mapped_column_ids contains the sorted order of Indexed physical column ID's (see unbound_index.hpp comments). + void ApplyBufferedReplays(const vector &table_types, BufferedIndexReplays &buffered_replays, const vector &mapped_column_ids); protected: //! Lock used for any changes to the index mutex lock; - //! Bound expressions used during expression execution + //! The vector of bound expressions to generate the Index keys based on a data chunk. + //! The leaves of the bound expressions are BoundReferenceExpressions. + //! These BoundReferenceExpressions contain offsets into the DataChunk to retrieve the columns + //! for the expression. + //! With these offsets into the DataChunk, the expression executor can now evaluate the expression + //! on incoming data chunks to generate the keys. vector> bound_expressions; private: diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp index 691a4aac6..65ffd167f 100644 --- a/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_allocator.hpp @@ -30,7 +30,8 @@ class FixedSizeAllocator { public: //! Construct a new fixed-size allocator - FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager); + FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager, + MemoryTag memory_tag = MemoryTag::ART_INDEX); //! Block manager of the database instance BlockManager &block_manager; @@ -152,6 +153,8 @@ class FixedSizeAllocator { void VerifyBuffers(); private: + //! Memory tag of memory that is allocated through the allocator + MemoryTag memory_tag; //! Allocation size of one segment in a buffer //! We only need this value to calculate bitmask_count, bitmask_offset, and //! available_segments_per_buffer diff --git a/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp index e7c5b6aa9..6ca7dc1aa 100644 --- a/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/fixed_size_buffer.hpp @@ -43,7 +43,7 @@ class FixedSizeBuffer { public: //! Constructor for a new in-memory buffer - explicit FixedSizeBuffer(BlockManager &block_manager); + explicit FixedSizeBuffer(BlockManager &block_manager, MemoryTag memory_tag); //! 
Constructor for deserializing buffer metadata from disk FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, const BlockPointer &block_pointer); diff --git a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp index ec2fc3cfd..0ca4aa9d2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp @@ -16,15 +16,61 @@ namespace duckdb { class ColumnDataCollection; +enum class BufferedIndexReplay : uint8_t { INSERT_ENTRY = 0, DEL_ENTRY = 1 }; + +struct ReplayRange { + BufferedIndexReplay type; + // [start, end) - start is inclusive, end is exclusive for the range within the ColumnDataCollection + // buffer for operations to replay for this range. + idx_t start; + idx_t end; + explicit ReplayRange(const BufferedIndexReplay replay_type, const idx_t start_p, const idx_t end_p) + : type(replay_type), start(start_p), end(end_p) { + } +}; + +// All inserts and deletes to be replayed are stored in their respective buffers. +// Since the inserts and deletes may be interleaved, however, ranges stores the ordering of operations +// and their offsets in the respective buffer. +// Simple example: +// ranges[0] - INSERT_ENTRY, [0,6) +// ranges[1] - DEL_ENTRY, [0,3) +// ranges[2] - INSERT_ENTRY [6,12) +// So even though the buffered_inserts has all the insert data from [0,12), ranges gives us the intervals for +// replaying the index operations in the right order. +struct BufferedIndexReplays { + vector ranges; + unique_ptr buffered_inserts; + unique_ptr buffered_deletes; + + BufferedIndexReplays() = default; + + unique_ptr &GetBuffer(const BufferedIndexReplay replay_type) { + if (replay_type == BufferedIndexReplay::INSERT_ENTRY) { + return buffered_inserts; + } + return buffered_deletes; + } + + bool HasBufferedReplays() const { + return !ranges.empty(); + } +}; + class UnboundIndex final : public Index { private: //! The CreateInfo of the index. unique_ptr create_info; //! The serialized storage information of the index. IndexStorageInfo storage_info; - //! Buffer for WAL replay appends. - unique_ptr buffered_appends; - //! Maps the column IDs in the buffered appends to the table columns. + + //! Buffered for index operations during WAL replay. They are replayed upon index binding. + BufferedIndexReplays buffered_replays; + + //! Maps the column IDs in the buffered replays to a physical table offset. + //! For example, column [i] in a buffered ColumnDataCollection is the data for an Indexed column with + //! physical table index mapped_column_ids[i]. + //! This is in sorted order of physical column IDs. vector mapped_column_ids; public: @@ -59,13 +105,19 @@ class UnboundIndex final : public Index { void CommitDrop() override; - void BufferChunk(DataChunk &chunk, Vector &row_ids, const vector &mapped_column_ids_p); - bool HasBufferedAppends() const { - return buffered_appends != nullptr; + //! Buffer Index delete or insert (replay_type) data chunk. + //! See note above on mapped_column_ids, this function assumes that index_column_chunk maps into + //! mapped_column_ids_p to get the physical column index for each Indexed column in the chunk. 
+ void BufferChunk(DataChunk &index_column_chunk, Vector &row_ids, const vector &mapped_column_ids_p, + BufferedIndexReplay replay_type); + bool HasBufferedReplays() const { + return buffered_replays.HasBufferedReplays(); } - ColumnDataCollection &GetBufferedAppends() const { - return *buffered_appends; + + BufferedIndexReplays &GetBufferedReplays() { + return buffered_replays; } + const vector &GetMappedColumnIds() const { return mapped_column_ids; } diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp index 4d0e6ae47..942997dbb 100644 --- a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -18,6 +18,7 @@ #include "duckdb/common/types/vector.hpp" #include "duckdb/execution/aggregate_hashtable.hpp" #include "duckdb/execution/ht_entry.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" namespace duckdb { @@ -214,6 +215,10 @@ class JoinHashTable { TupleDataCollection &GetDataCollection() { return *data_collection; } + //! Perform a full scan of a build column, filling the provided addresses vector and result vector. + //! Returns the number of tuples found (can be smaller than the vector capacity). + idx_t ScanKeyColumn(Vector &addresses, Vector &result, idx_t column_index) const; + bool NullValuesAreEqual(idx_t col_idx) const { return null_values_are_equal[col_idx]; } @@ -279,6 +284,8 @@ class JoinHashTable { bool single_join_error_on_multiple_rows = true; //! Whether or not to perform deduplication based on join_keys when building ht bool insert_duplicate_keys = true; + //! Number of probe matches + atomic total_probe_matches {0}; struct { mutex mj_lock; @@ -333,6 +340,10 @@ class JoinHashTable { //! An empty tuple that's a "dead end", can be used to stop chains early unsafe_unique_array dead_end; + //! Whether or not to use a bloom filter will be determined by the operator + BloomFilter bloom_filter; + bool should_build_bloom_filter = false; + //! Copying not allowed JoinHashTable(const JoinHashTable &) = delete; @@ -397,13 +408,13 @@ class JoinHashTable { static constexpr double DEFAULT_LOAD_FACTOR = 2.0; //! For a LOAD_FACTOR of 1.5, the HT is between 33% and 67% full static constexpr double EXTERNAL_LOAD_FACTOR = 1.5; + //! Minimum capacity of the pointer table + static constexpr idx_t MINIMUM_CAPACITY = 16384; double load_factor = DEFAULT_LOAD_FACTOR; //! Capacity of the pointer table given the ht count idx_t PointerTableCapacity(idx_t count) const { - static constexpr idx_t MINIMUM_CAPACITY = 16384; - const auto capacity = NextPowerOfTwo(LossyNumericCast(static_cast(count) * load_factor)); return MaxValue(capacity, MINIMUM_CAPACITY); } @@ -412,6 +423,14 @@ class JoinHashTable { return PointerTableCapacity(count) * sizeof(data_ptr_t); } + void SetBuildBloomFilter(const bool should_build) { + this->should_build_bloom_filter = should_build; + } + + BloomFilter &GetBloomFilter() { + return bloom_filter; + } + //! 
Get total size of HT if all partitions would be built idx_t GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, idx_t &max_partition_count) const; diff --git a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp index 8c04ecde0..d17e6944f 100644 --- a/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/execution/merge_sort_tree.hpp @@ -574,7 +574,6 @@ template template void MergeSortTree::AggregateLowerBound(const idx_t lower, const idx_t upper, const E needle, L aggregate) const { - if (lower >= upper) { return; } diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp index 45593a0d7..2745432c5 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp @@ -95,7 +95,8 @@ class PhysicalHashAggregate : public PhysicalOperator { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp index 71a2415e9..4a8b39fe0 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp @@ -37,7 +37,8 @@ class PhysicalPartitionedAggregate : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; bool IsSource() const override { diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp index b30373624..939f743b3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp @@ -34,7 +34,8 @@ class PhysicalPerfectHashAggregate : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp index 88d2a20b6..c350ed80c 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp @@ -35,7 +35,8 @@ class PhysicalUngroupedAggregate : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp index 8629c7070..c35fe02c0 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp @@ -36,7 +36,8 @@ class PhysicalWindow : public PhysicalOperator { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/base_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/base_scanner.hpp index 2a123827c..5e36690ee 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/base_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/base_scanner.hpp @@ -121,6 +121,8 @@ class BaseScanner { virtual ~BaseScanner() = default; + void Print() const; + //! Returns true if the scanner is finished bool FinishedFile() const; @@ -164,10 +166,15 @@ class BaseScanner { //! States CSVStates states; + //! If the scanner ever entered a quoted state bool ever_quoted = false; + //! If the scanner ever entered an escaped state. bool ever_escaped = false; + //! If the scanner ever used advantage of the non-strict mode. + bool used_unstrictness = false; + //! 
Shared pointer to the buffer_manager, this is shared across multiple scanners shared_ptr buffer_manager; @@ -302,6 +309,9 @@ class BaseScanner { !state_machine->dialect_options.state_machine_options.strict_mode.GetValue())) { // We only set the ever escaped variable if this is either a quote char OR strict mode is off ever_escaped = true; + if (states.states[0] == CSVState::UNQUOTED_ESCAPE) { + used_unstrictness = true; + } } ever_quoted = true; T::SetQuoted(result, iterator.pos.buffer_pos); @@ -332,11 +342,15 @@ class BaseScanner { break; } case CSVState::ESCAPE: - case CSVState::UNQUOTED_ESCAPE: case CSVState::ESCAPED_RETURN: T::SetEscaped(result); iterator.pos.buffer_pos++; break; + case CSVState::UNQUOTED_ESCAPE: + T::SetEscaped(result); + iterator.pos.buffer_pos++; + used_unstrictness = true; + break; case CSVState::STANDARD: { iterator.pos.buffer_pos++; while (iterator.pos.buffer_pos + 8 < to_pos) { diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp index 30ba0abc5..aad90df94 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp @@ -52,21 +52,21 @@ class CSVError { CSVError() {}; CSVError(string error_message, CSVErrorType type, idx_t column_idx, string csv_row, LinesPerBoundary error_info, idx_t row_byte_position, optional_idx byte_position, const CSVReaderOptions &reader_options, - const string &fixes, const string ¤t_path); + const string &fixes, const String ¤t_path); CSVError(string error_message, CSVErrorType type, LinesPerBoundary error_info); //! Produces error messages for column name -> type mismatch. static CSVError ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names); //! Produces error messages for casting errors static CSVError CastError(const CSVReaderOptions &options, const string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - optional_idx byte_position, LogicalTypeId type, const string ¤t_path); + optional_idx byte_position, LogicalTypeId type, const String ¤t_path); //! Produces error for when the line size exceeds the maximum line size option static CSVError LineSizeError(const CSVReaderOptions &options, LinesPerBoundary error_info, string &csv_row, - idx_t byte_position, const string ¤t_path); + idx_t byte_position, const String ¤t_path); //! Produces error for when the state machine reaches an invalid state static CSVError InvalidState(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path); + const String ¤t_path); //! Produces an error message for a dialect sniffing error. static CSVError SniffingError(const CSVReaderOptions &options, const string &search_space, idx_t max_columns_found, SetColumns &set_columns, bool type_detection); @@ -76,17 +76,17 @@ class CSVError { //! Produces error messages for unterminated quoted values static CSVError UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path); + optional_idx byte_position, const String ¤t_path); //! 
Produces error messages for null_padding option is set, and we have quoted new values in parallel static CSVError NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info, - const string ¤t_path); + const String ¤t_path); //! Produces error for incorrect (e.g., smaller and lower than the predefined) number of columns in a CSV Line static CSVError IncorrectColumnAmountError(const CSVReaderOptions &state_machine, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string ¤t_path); + optional_idx byte_position, const String ¤t_path); static CSVError InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path); + const String ¤t_path); idx_t GetBoundaryIndex() const { return error_info.boundary_idx; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp index 324501a3d..5446739ad 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp @@ -59,8 +59,8 @@ class CSVFileScan : public BaseFileReader { void PrepareReader(ClientContext &context, GlobalTableFunctionState &) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate_p) override; double GetProgressInFile(ClientContext &context) override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp index b3d4f9dd5..744710bef 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp @@ -196,7 +196,7 @@ struct CSVReaderOptions { //! Verify options are not conflicting void Verify(MultiFileOptions &file_options); - string ToString(const string ¤t_file_path) const; + string ToString(const String ¤t_file_path) const; //! 
If the type for column with idx i was manually set bool WasTypeManuallySet(idx_t i) const; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp index 45aeaad9b..f04bbd814 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp @@ -11,6 +11,7 @@ #include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" #include "duckdb/execution/operator/csv_scanner/csv_buffer_manager.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp" +#include "duckdb/common/printer.hpp" namespace duckdb { @@ -129,12 +130,12 @@ class CSVStateMachine { } void Print() const { - std::cout << "State Machine Options" << '\n'; - std::cout << "Delim: " << state_machine_options.delimiter.GetValue() << '\n'; - std::cout << "Quote: " << state_machine_options.quote.GetValue() << '\n'; - std::cout << "Escape: " << state_machine_options.escape.GetValue() << '\n'; - std::cout << "Comment: " << state_machine_options.comment.GetValue() << '\n'; - std::cout << "---------------------" << '\n'; + Printer::Print(OutputStream::STREAM_STDOUT, string("State Machine Options")); + Printer::Print(OutputStream::STREAM_STDOUT, string("Delim: ") + state_machine_options.delimiter.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Quote: ") + state_machine_options.quote.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Escape: ") + state_machine_options.escape.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("Comment: ") + state_machine_options.comment.FormatValue()); + Printer::Print(OutputStream::STREAM_STDOUT, string("---------------------")); } //! The Transition Array is a Finite State Machine //! It holds the transitions of all states, on all 256 possible different characters diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp index ad4f8f1b2..0e2bb1d0a 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp @@ -47,7 +47,7 @@ class StateMachine { //! Hash function used in out state machine cache, it hashes and combines all options used to generate a state machine struct HashCSVStateMachineConfig { - size_t operator()(CSVStateMachineOptions const &config) const noexcept { + hash_t operator()(CSVStateMachineOptions const &config) const noexcept { auto h_delimiter = Hash(config.delimiter.GetValue().c_str()); auto h_quote = Hash(config.quote.GetValue()); auto h_escape = Hash(config.escape.GetValue()); diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp index 5b985e05c..8ab081ae3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp @@ -116,6 +116,7 @@ class CSVSniffer { //! Highest number of columns found idx_t max_columns_found = 0; idx_t max_columns_found_error = 0; + bool best_candidate_is_strict = false; //! 
Current Candidates being considered vector> candidates; //! Reference to original CSV Options, it will be modified as a result of the sniffer. diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp index bacabfc4f..e2987d955 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp @@ -41,13 +41,15 @@ class FullLinePosition { return {}; } string result; - if (end.buffer_idx == begin.buffer_idx) { - if (buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { + if (end.buffer_idx == begin.buffer_idx || begin.buffer_pos == begin.buffer_size) { + idx_t buffer_idx = end.buffer_idx; + if (buffer_handles.find(buffer_idx) == buffer_handles.end()) { return {}; } - auto buffer = buffer_handles[begin.buffer_idx]->Ptr(); - first_char_nl = buffer[begin.buffer_pos] == '\n' || buffer[begin.buffer_pos] == '\r'; - for (idx_t i = begin.buffer_pos + first_char_nl; i < end.buffer_pos; i++) { + idx_t start_pos = begin.buffer_pos == begin.buffer_size ? 0 : begin.buffer_pos; + auto buffer = buffer_handles[buffer_idx]->Ptr(); + first_char_nl = buffer[start_pos] == '\n' || buffer[start_pos] == '\r'; + for (idx_t i = start_pos + first_char_nl; i < end.buffer_pos; i++) { result += buffer[i]; } } else { @@ -55,6 +57,9 @@ class FullLinePosition { buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { return {}; } + if (begin.buffer_pos >= begin.buffer_size) { + throw InternalException("CSV reader: buffer pos out of range for buffer"); + } auto first_buffer = buffer_handles[begin.buffer_idx]->Ptr(); auto first_buffer_size = buffer_handles[begin.buffer_idx]->actual_size; auto second_buffer = buffer_handles[end.buffer_idx]->Ptr(); @@ -176,7 +181,7 @@ class StringValueResult : public ScannerResult { const shared_ptr &buffer_handle, Allocator &buffer_allocator, idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_handler, CSVIterator &iterator, bool store_line_size, shared_ptr csv_file_scan, idx_t &lines_read, bool sniffing, - string path, idx_t scan_id); + const string &path, idx_t scan_id, bool &used_unstrictness); ~StringValueResult(); @@ -225,6 +230,7 @@ class StringValueResult : public ScannerResult { shared_ptr csv_file_scan; idx_t &lines_read; + bool &used_unstrictness; //! Information regarding projected columns unsafe_unique_array projected_columns; bool projecting_columns = false; @@ -248,7 +254,7 @@ class StringValueResult : public ScannerResult { //! We store borked rows so we can generate multiple errors during flushing unordered_set borked_rows; - const string path; + String path; //! Variable used when trying to figure out where a new segment starts, we must always start from a Valid //! (i.e., non-comment) line. 
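One small but easy-to-miss piece of the string_value_scanner.hpp hunk above is the extra boundary case in ReconstructCurrentLine: when the recorded begin position sits exactly at the end of its buffer (begin.buffer_pos == begin.buffer_size), the line is reconstructed from offset 0 of the buffer that holds the end position rather than indexing past the old buffer. A minimal sketch of just that start-offset rule, with simplified stand-in types rather than the actual scanner classes:

#include <cstdint>
#include <utility>

using idx_t = uint64_t;

// Illustrative stand-in for the scanner's begin/end positions (not the real class).
struct LinePos {
	idx_t buffer_idx;
	idx_t buffer_pos;
	idx_t buffer_size;
};

// Returns {buffer index, starting offset} for reconstructing the current line,
// mirroring the condition added above: a begin position at the very end of its
// buffer means the line effectively starts at offset 0 of the end buffer.
std::pair<idx_t, idx_t> LineStart(const LinePos &begin, const LinePos &end) {
	if (end.buffer_idx == begin.buffer_idx || begin.buffer_pos == begin.buffer_size) {
		idx_t start_pos = begin.buffer_pos == begin.buffer_size ? 0 : begin.buffer_pos;
		return {end.buffer_idx, start_pos};
	}
	// Otherwise the line spans two buffers and starts at the recorded offset in the begin buffer.
	return {begin.buffer_idx, begin.buffer_pos};
}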
diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp index 2a7425279..ff6365f6b 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_batch_collector.hpp @@ -18,7 +18,7 @@ class PhysicalBatchCollector : public PhysicalResultCollector { PhysicalBatchCollector(PhysicalPlan &physical_plan, PreparedStatementData &data); public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface @@ -44,7 +44,8 @@ class PhysicalBatchCollector : public PhysicalResultCollector { //===--------------------------------------------------------------------===// class BatchCollectorGlobalState : public GlobalSinkState { public: - BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + BatchCollectorGlobalState(ClientContext &context, const PhysicalBatchCollector &op) + : data(context, op.types, op.memory_type) { } mutex glock; @@ -54,7 +55,8 @@ class BatchCollectorGlobalState : public GlobalSinkState { class BatchCollectorLocalState : public LocalSinkState { public: - BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) : data(context, op.types) { + BatchCollectorLocalState(ClientContext &context, const PhysicalBatchCollector &op) + : data(context, op.types, op.memory_type) { } BatchedDataCollection data; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp index 74865c5d5..cb5e8892f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp @@ -26,7 +26,7 @@ class PhysicalBufferedBatchCollector : public PhysicalResultCollector { PhysicalBufferedBatchCollector(PhysicalPlan &physical_plan, PreparedStatementData &data); public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp index 16dfdafb7..fadd37bf0 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_buffered_collector.hpp @@ -20,7 +20,7 @@ class PhysicalBufferedCollector : public PhysicalResultCollector { bool parallel; public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_create_secret.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_create_secret.hpp index 7ebebb51a..d21d8c0f3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_create_secret.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_create_secret.hpp @@ -29,7 +29,8 @@ class PhysicalCreateSecret : public PhysicalOperator { public: // Source 
interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp index 1d8097aa3..b0d83db96 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_explain_analyze.hpp @@ -28,7 +28,8 @@ class PhysicalExplainAnalyze : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp index 2ec2288a4..0c5ab2172 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit.hpp @@ -36,7 +36,8 @@ class PhysicalLimit : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp index 5687674e0..5f9613a3a 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_limit_percent.hpp @@ -34,7 +34,8 @@ class PhysicalLimitPercent : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp index 229c43693..103d19fb0 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_load.hpp @@ -28,7 +28,8 @@ class PhysicalLoad : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git 
a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp index 2c73eee8d..5ed71542f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_materialized_collector.hpp @@ -21,7 +21,7 @@ class PhysicalMaterializedCollector : public PhysicalResultCollector { bool parallel; public: - unique_ptr GetResult(GlobalSinkState &state) override; + unique_ptr GetResult(GlobalSinkState &state) const override; public: // Sink interface @@ -35,20 +35,4 @@ class PhysicalMaterializedCollector : public PhysicalResultCollector { bool SinkOrderDependent() const override; }; -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MaterializedCollectorGlobalState : public GlobalSinkState { -public: - mutex glock; - unique_ptr collection; - shared_ptr context; -}; - -class MaterializedCollectorLocalState : public LocalSinkState { -public: - unique_ptr collection; - ColumnDataAppendState append_state; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp index 313217c59..77a8c7db7 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_pragma.hpp @@ -29,7 +29,8 @@ class PhysicalPragma : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp index 373b82e5d..5f3e6cb23 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp @@ -11,6 +11,7 @@ #include "duckdb/execution/physical_operator.hpp" #include "duckdb/common/enums/physical_operator_type.hpp" #include "duckdb/main/prepared_statement_data.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { @@ -19,18 +20,19 @@ class PhysicalPrepare : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PREPARE; public: - PhysicalPrepare(PhysicalPlan &physical_plan, string name_p, shared_ptr prepared, + PhysicalPrepare(PhysicalPlan &physical_plan, const std::string &name_p, shared_ptr prepared, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::PREPARE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(std::move(name_p)), prepared(std::move(prepared)) { + name(physical_plan.ArenaRef().MakeString(name_p)), prepared(std::move(prepared)) { } - string name; + String name; shared_ptr prepared; public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk 
&chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp index 7ab91a91f..c38844e59 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reservoir_sample.hpp @@ -29,7 +29,8 @@ class PhysicalReservoirSample : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp index 16157c0f2..aaf593db8 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp @@ -31,7 +31,8 @@ class PhysicalReset : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp index c141dad66..e654a8e9d 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp @@ -13,7 +13,9 @@ #include "duckdb/common/enums/statement_type.hpp" namespace duckdb { + class PreparedStatementData; +class ColumnDataCollection; //! PhysicalResultCollector is an abstract class that is used to generate the final result of a query class PhysicalResultCollector : public PhysicalOperator { @@ -25,6 +27,7 @@ class PhysicalResultCollector : public PhysicalOperator { StatementType statement_type; StatementProperties properties; + QueryResultMemoryType memory_type; PhysicalOperator &plan; vector names; @@ -33,7 +36,7 @@ class PhysicalResultCollector : public PhysicalOperator { public: //! 
The final method used to fetch the query result from this operator - virtual unique_ptr GetResult(GlobalSinkState &state) = 0; + virtual unique_ptr GetResult(GlobalSinkState &state) const = 0; bool IsSink() const override { return true; @@ -52,6 +55,9 @@ class PhysicalResultCollector : public PhysicalOperator { virtual bool IsStreaming() const { return false; } + +protected: + unique_ptr CreateCollection(ClientContext &context) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp index c7f2fb038..88345b3cc 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/execution/physical_operator.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/execution/physical_plan_generator.hpp" namespace duckdb { @@ -26,24 +27,25 @@ class PhysicalSet : public PhysicalOperator { PhysicalSet(PhysicalPlan &physical_plan, const string &name_p, Value value_p, SetScope scope_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET, {LogicalType::BOOLEAN}, estimated_cardinality), - name(name_p), value(std::move(value_p)), scope(scope_p) { + name(physical_plan.ArenaRef().MakeString(name_p)), value(std::move(value_p)), scope(scope_p) { } public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; } - static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, + static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const String &name, SetScope scope, const Value &value); - static void SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value); + static void SetGenericVariable(ClientContext &context, const String &name, SetScope scope, Value target_value); public: - const string name; + String name; const Value value; const SetScope scope; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp index 4574cd868..2e1a0252f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp @@ -18,11 +18,12 @@ class PhysicalSetVariable : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::SET_VARIABLE; public: - PhysicalSetVariable(PhysicalPlan &physical_plan, string name, idx_t estimated_cardinality); + PhysicalSetVariable(PhysicalPlan &physical_plan, const string &name_p, idx_t estimated_cardinality); public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; @@ -37,7 +38,7 @@ class PhysicalSetVariable : public 
PhysicalOperator { } public: - const string name; + String name; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp index ea3029ea5..ef495644a 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_transaction.hpp @@ -29,7 +29,8 @@ class PhysicalTransaction : public PhysicalOperator { unique_ptr info; public: - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_update_extensions.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_update_extensions.hpp index dab694531..de2f2a67f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_update_extensions.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_update_extensions.hpp @@ -43,7 +43,8 @@ class PhysicalUpdateExtensions : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp index 4999c9c7f..93ab05a85 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_vacuum.hpp @@ -29,7 +29,8 @@ class PhysicalVacuum : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp index ad7f39223..bbaa04262 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp @@ -64,6 +64,9 @@ struct JoinFilterPushdownInfo { //! Min/Max aggregates vector> min_max_aggregates; + //! 
Whether the build side has a filter -> we might be able to push down a bloom filter into the probe side + bool build_side_has_filter; + public: unique_ptr GetGlobalState(ClientContext &context, const PhysicalOperator &op) const; unique_ptr GetLocalState(JoinFilterGlobalState &gstate) const; @@ -73,9 +76,22 @@ struct JoinFilterPushdownInfo { unique_ptr Finalize(ClientContext &context, optional_ptr ht, JoinFilterGlobalState &gstate, const PhysicalComparisonJoin &op) const; + unique_ptr FinalizeMinMax(JoinFilterGlobalState &gstate) const; + unique_ptr FinalizeFilters(ClientContext &context, optional_ptr ht, + const PhysicalComparisonJoin &op, unique_ptr final_min_max, + bool is_perfect_hashtable) const; + private: void PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, idx_t filter_col_idx) const; + + void PushBloomFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, + idx_t filter_col_idx) const; + + bool CanUseInFilter(const ClientContext &context, optional_ptr ht, const ExpressionType &cmp) const; + bool CanUseBloomFilter(const ClientContext &context, optional_ptr ht, + const PhysicalComparisonJoin &op, const ExpressionType &cmp, + bool is_perfect_hashtable) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp index b90a26767..6d853e6f8 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/outer_join_marker.hpp @@ -31,7 +31,7 @@ class OuterJoinMarker { public: explicit OuterJoinMarker(bool enabled); - bool Enabled() { + bool Enabled() const { return enabled; } //! Initializes the outer join counter diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp index 0affaf4cc..9e5134481 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp @@ -29,7 +29,7 @@ struct PerfectHashJoinStats { //! PhysicalHashJoin represents a hash loop join between two tables class PerfectHashJoinExecutor { - using PerfectHashTable = vector; + using PerfectHashTable = vector>; public: PerfectHashJoinExecutor(const PhysicalHashJoin &join, JoinHashTable &ht); @@ -64,7 +64,7 @@ class PerfectHashJoinExecutor { //! Build statistics PerfectHashJoinStats perfect_join_statistics; //! Stores the occurrences of each value in the build side - unsafe_unique_array bitmap_build_idx; + ValidityMask bitmap_build_idx; //! 
Stores the number of unique keys in the build side idx_t unique_keys = 0; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp index 6089a728b..95496fd84 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp @@ -37,18 +37,6 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { // Projection mappings vector right_projection_map; - // Predicate (join conditions that don't reference both sides) - unique_ptr predicate; - -public: - // Operator Interface - unique_ptr GetGlobalOperatorState(ClientContext &context) const override; - unique_ptr GetOperatorState(ExecutionContext &context) const override; - - bool ParallelOperator() const override { - return true; - } - protected: // CachingOperator Interface OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, @@ -59,7 +47,8 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; @@ -83,6 +72,9 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { bool ParallelSink() const override { return true; } + +public: + void BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp index 69905e7a8..304479383 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_blockwise_nl_join.hpp @@ -44,7 +44,8 @@ class PhysicalBlockwiseNLJoin : public PhysicalJoin { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return PropagatesBuildSide(join_type); diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp index 46d6b47c3..c8b9c58e7 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp @@ -74,7 +74,8 @@ class PhysicalHashJoin : public PhysicalComparisonJoin { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType 
GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp index b57fe772d..219b51786 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_iejoin.hpp @@ -41,7 +41,8 @@ class PhysicalIEJoin : public PhysicalRangeJoin { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; @@ -70,10 +71,6 @@ class PhysicalIEJoin : public PhysicalRangeJoin { public: void BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) override; - -private: - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) - void ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp index 25ed9ed06..2536a2016 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp @@ -18,13 +18,16 @@ class PhysicalNestedLoopJoin : public PhysicalComparisonJoin { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::NESTED_LOOP_JOIN; public: - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info); - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality); + // Predicate (join conditions that don't reference both sides) + unique_ptr predicate; + public: // Operator Interface unique_ptr GetOperatorState(ExecutionContext &context) const override; @@ -43,7 +46,8 @@ class PhysicalNestedLoopJoin : public PhysicalComparisonJoin { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return PropagatesBuildSide(join_type); diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp index 4da01aff3..faed6e0c1 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp @@ -46,7 +46,10 @@ class PhysicalPiecewiseMergeJoin : public PhysicalRangeJoin { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return PropagatesBuildSide(join_type); diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp index 11a678c62..51fd38043 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_positional_join.hpp @@ -29,7 +29,8 @@ class PhysicalPositionalJoin : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 4ee6ef557..1edb36ed4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -1,43 +1,42 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// duckdb/execution/operator/join/physical_range_join.hpp // // //===----------------------------------------------------------------------===// #pragma once +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/operator/join/physical_comparison_join.hpp" -#include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" namespace duckdb { -struct GlobalSortState; - //! PhysicalRangeJoin represents one or more inequality range join predicates between //! two tables class PhysicalRangeJoin : public PhysicalComparisonJoin { public: + class GlobalSortedTable; + class LocalSortedTable { public: - LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child); + LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child); - void Sink(DataChunk &input, GlobalSortState &global_sort_state); + void Sink(ExecutionContext &context, DataChunk &input); - inline void Sort(GlobalSortState &global_sort_state) { - local_sort_state.Sort(global_sort_state, true); - } - - //! The hosting operator - const PhysicalRangeJoin &op; + //! 
The global table we are connected to + GlobalSortedTable &global_table; //! The local sort state - LocalSortState local_sort_state; + unique_ptr local_sink; //! Local copy of the sorting expression executor ExpressionExecutor executor; //! Holds a vector of incoming sorting columns DataChunk keys; + //! The sort data + DataChunk sort_chunk; //! The number of NULL values idx_t has_null; //! The total number of rows @@ -50,45 +49,89 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { class GlobalSortedTable { public: - GlobalSortedTable(ClientContext &context, const vector &orders, RowLayout &payload_layout, - const PhysicalOperator &op); + GlobalSortedTable(ClientContext &client, const vector &orders, + const vector &payload_layout, const PhysicalRangeJoin &op); inline idx_t Count() const { return count; } inline idx_t BlockCount() const { - if (global_sort_state.sorted_blocks.empty()) { - return 0; - } - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - return global_sort_state.sorted_blocks[0]->radix_sorting_data.size(); + return sorted->key_data->ChunkCount(); + } + + inline idx_t BlockStart(idx_t i) const { + return MinValue(i * STANDARD_VECTOR_SIZE, count); + } + + inline idx_t BlockEnd(idx_t i) const { + return BlockStart(i + 1) - 1; } inline idx_t BlockSize(idx_t i) const { - return global_sort_state.sorted_blocks[0]->radix_sorting_data[i]->count; + return i < BlockCount() ? MinValue(STANDARD_VECTOR_SIZE, count - BlockStart(i)) : 0; + } + + inline SortKeyType GetSortKeyType() const { + return sorted->key_data->GetLayout().GetSortKeyType(); } - void Combine(LocalSortedTable &ltable); void IntializeMatches(); + + //! Combine local states + void Combine(ExecutionContext &context, LocalSortedTable &ltable); + //! Prepare for sorting. + void Finalize(ClientContext &client, InterruptState &interrupt); + //! Schedules the materialisation process. + void Materialize(Pipeline &pipeline, Event &event); + //! Single-threaded materialisation. + void Materialize(ExecutionContext &context, InterruptState &interrupt); + //! Materialize an empty sorted run. + void MaterializeEmpty(ClientContext &client); + //! Print the table to the console void Print(); - //! Starts the sorting process. - void Finalize(Pipeline &pipeline, Event &event); - //! Schedules tasks to merge sort the current child's data during a Finalize phase - void ScheduleMergeTasks(Pipeline &pipeline, Event &event); + //! Create an iteration state + unique_ptr CreateIteratorState() { + auto state = make_uniq(*sorted->key_data, sorted->payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + iter.SetKeepPinned(true); + iter.SetPinPayload(true); + } + //! Create an iteration state + unique_ptr CreateScanState(ClientContext &client) { + return make_uniq(client, *sort); + } + //! Initialize a payload scanning state + void InitializePayloadState(TupleDataChunkState &state) { + sorted->payload_data->InitializeChunkState(state); + } //! The hosting operator - const PhysicalOperator &op; - GlobalSortState global_sort_state; + const PhysicalRangeJoin &op; + //! The sort description + unique_ptr sort; + //! The shared sort state + unique_ptr global_sink; //! Whether or not the RHS has NULL values atomic has_null; //! The total number of rows in the RHS atomic count; + //! 
The number of materialisation tasks completed in parallel + atomic tasks_completed; + //! The shared materialisation state + unique_ptr global_source; + //! The materialized data + unique_ptr sorted; //! A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) unsafe_unique_array found_match; - //! Memory usage per thread - idx_t memory_per_thread; }; public: @@ -106,10 +149,9 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { public: // Gather the result values and slice the payload columns to those values. - // Returns a buffer handle to the pinned heap block (if any) - static BufferHandle SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols = 0); + static void SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, ExternalBlockIteratorState &state, + TupleDataChunkState &chunk_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count, SortedRunScanState &scan_state); // Apply a tail condition to the current selection static idx_t SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel); diff --git a/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp b/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp index 4c0be847f..3b45553c4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/order/physical_order.hpp @@ -36,7 +36,8 @@ class PhysicalOrder : public PhysicalOperator { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp b/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp index 0f67b95e8..dd79b094f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/order/physical_top_n.hpp @@ -36,7 +36,8 @@ class PhysicalTopN : public PhysicalOperator { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp index f376346e3..f45668468 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp @@ -34,7 +34,8 @@ class PhysicalBatchCopyToFile : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp index 8b9beb983..f5b94c46e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp @@ -37,7 +37,8 @@ class PhysicalBatchInsert : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_database.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_database.hpp index d427b9481..73986a842 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_database.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_database.hpp @@ -26,7 +26,8 @@ class PhysicalCopyDatabase : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp index 6b4486811..2e8d381e9 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp @@ -60,7 +60,8 @@ class PhysicalCopyToFile : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp index 0adeb37cd..4372de5a9 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_delete.hpp @@ 
-33,7 +33,8 @@ class PhysicalDelete : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp index 29274d8fb..29543a64e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_export.hpp @@ -47,7 +47,8 @@ class PhysicalExport : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp index 14680df1e..807b95984 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_insert.hpp @@ -122,7 +122,8 @@ class PhysicalInsert : public PhysicalOperator { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp index 9cbba99c1..406493065 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp @@ -57,7 +57,8 @@ class PhysicalMergeInto : public PhysicalOperator { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp index 6bf191861..1ec0b5f07 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_update.hpp @@ -41,7 +41,8 @@ class PhysicalUpdate : public PhysicalOperator { public: // Source interface unique_ptr 
GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp index 5a5ca7722..4820db315 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp @@ -46,6 +46,10 @@ class PhysicalTableInOutFunction : public PhysicalOperator { //! Information for WITH ORDINALITY optional_idx ordinality_idx; + OrderPreservationType OperatorOrder() const override { + return function.order_preservation_type; + } + private: //! The table function TableFunction function; diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp index cd50834ef..7dfedef79 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_column_data_scan.hpp @@ -36,7 +36,8 @@ class PhysicalColumnDataScan : public PhysicalOperator { unique_ptr GetGlobalSourceState(ClientContext &context) const override; unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp index 704241e70..1cb09bdec 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_dummy_scan.hpp @@ -22,7 +22,8 @@ class PhysicalDummyScan : public PhysicalOperator { } public: - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp index db352ea78..3f8a72e27 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_empty_result.hpp @@ -22,7 +22,8 @@ class PhysicalEmptyResult : public PhysicalOperator { } public: - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git 
a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp index bf7350c70..9d7d7b2d4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_positional_scan.hpp @@ -36,7 +36,8 @@ class PhysicalPositionalScan : public PhysicalOperator { unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp index 811a1abda..90b8efece 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp @@ -14,6 +14,7 @@ #include "duckdb/storage/data_table.hpp" #include "duckdb/common/extra_operator_info.hpp" #include "duckdb/common/column_index.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" namespace duckdb { @@ -60,11 +61,16 @@ class PhysicalTableScan : public PhysicalOperator { bool Equals(const PhysicalOperator &other) const override; + OrderPreservationType SourceOrder() const override { + return function.order_preservation_type; + } + public: unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const override; unique_ptr GetGlobalSourceState(ClientContext &context) const override; - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp index af3b51539..a93d3de5c 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_alter.hpp @@ -28,7 +28,8 @@ class PhysicalAlter : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp index 9b4d85f83..b0ea478eb 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_attach.hpp @@ -28,7 +28,8 @@ class 
PhysicalAttach : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp index e134f835a..b6cf91a46 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_art_index.hpp @@ -45,7 +45,8 @@ class PhysicalCreateARTIndex : public PhysicalOperator { public: //! Source interface, NOP for this operator - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp index 0b8d832ec..0fcaf3563 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_function.hpp @@ -30,7 +30,8 @@ class PhysicalCreateFunction : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp index d76117b5f..36bd41dfa 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_schema.hpp @@ -30,7 +30,8 @@ class PhysicalCreateSchema : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp index 77af76502..947933acf 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_sequence.hpp @@ -30,7 +30,8 @@ class PhysicalCreateSequence : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff 
--git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp index 0b2240cc6..e9163571f 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_table.hpp @@ -29,7 +29,8 @@ class PhysicalCreateTable : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp index 7b44b31a2..e069c7f5a 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_type.hpp @@ -25,7 +25,8 @@ class PhysicalCreateType : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp index ca6490926..9185e8936 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_create_view.hpp @@ -30,7 +30,8 @@ class PhysicalCreateView : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp index 3b180e822..d63f31d7c 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_detach.hpp @@ -27,7 +27,8 @@ class PhysicalDetach : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp index 85cd57bb9..63b9da0e8 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/schema/physical_drop.hpp @@ -28,7 +28,8 @@ class PhysicalDrop : public PhysicalOperator { public: // Source interface - SourceResultType 
GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp index c86a8051e..e21e531f6 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp @@ -48,7 +48,8 @@ class PhysicalRecursiveCTE : public PhysicalOperator { public: // Source interface - SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; bool IsSource() const override { return true; diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp index bdc54a415..ee205bbae 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp @@ -123,7 +123,13 @@ class PhysicalOperator { virtual unique_ptr GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const; virtual unique_ptr GetGlobalSourceState(ClientContext &context) const; - virtual SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const; + +protected: + virtual SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const; + +public: + SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const; virtual OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp index 9d195c8fe..a471c229f 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp @@ -63,6 +63,10 @@ class GlobalOperatorState { DynamicCastCheck(this); return reinterpret_cast(*this); } + + virtual idx_t MaxThreads(idx_t source_max_threads) { + return source_max_threads; + } }; class GlobalSinkState : public StateWithBlockableTasks { diff --git a/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp index 5d9e0aa46..1a1080e7b 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/group_by_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/planner/logical_tokens.hpp" #include "duckdb/planner/joinside.hpp" @@ -152,7 +153,8 @@ class PhysicalPlanGenerator { PhysicalOperator &PlanComparisonJoin(LogicalComparisonJoin &op); PhysicalOperator &PlanDelimJoin(LogicalComparisonJoin &op); PhysicalOperator &ExtractAggregateExpressions(PhysicalOperator &child, vector> &expressions, 
- vector> &groups); + vector> &groups, + optional_ptr> grouping_sets); private: ClientContext &context; diff --git a/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp b/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp new file mode 100644 index 000000000..eb7886651 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/physical_table_scan_enum.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/physical_table_scan_enum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace duckdb { + +enum class PhysicalTableScanExecutionStrategy : uint8_t { + DEFAULT, + TASK_EXECUTOR, + SYNCHRONOUS, + TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS +}; + +}; // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp index a26772819..666369cc7 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp @@ -242,6 +242,30 @@ class BinaryAggregateHeap { idx_t size; }; +enum class ArgMinMaxNullHandling { IGNORE_ANY_NULL, HANDLE_ARG_NULL, HANDLE_ANY_NULL }; + +struct ArgMinMaxFunctionData : FunctionData { + explicit ArgMinMaxFunctionData(ArgMinMaxNullHandling null_handling_p = ArgMinMaxNullHandling::IGNORE_ANY_NULL, + bool nulls_last_p = true) + : null_handling(null_handling_p), nulls_last(nulls_last_p) { + } + + unique_ptr Copy() const override { + auto copy = make_uniq(); + copy->null_handling = null_handling; + copy->nulls_last = nulls_last; + return std::move(copy); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return other.null_handling == null_handling && other.nulls_last == nulls_last; + } + + ArgMinMaxNullHandling null_handling; + bool nulls_last; +}; + //------------------------------------------------------------------------------ // Specializations for fixed size types, strings, and anything else (using sortkey) //------------------------------------------------------------------------------ @@ -254,7 +278,7 @@ struct MinMaxFixedValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = value; } @@ -263,7 +287,8 @@ struct MinMaxFixedValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -276,7 +301,7 @@ struct MinMaxStringValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { FlatVector::GetData(vector)[idx] = StringVector::AddStringOrBlob(vector, value); } @@ -285,7 +310,8 @@ struct MinMaxStringValue { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format) { + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, 
UnifiedVectorFormat &format, + const bool nulls_last) { input.ToUnifiedFormat(count, format); } }; @@ -299,8 +325,9 @@ struct MinMaxFallbackValue { return UnifiedVectorFormat::GetData(format)[idx]; } - static void Assign(Vector &vector, const idx_t idx, const TYPE &value) { - OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + auto order_by_null_type = nulls_last ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::DecodeSortKey(value, vector, idx, modifiers); } @@ -308,14 +335,61 @@ struct MinMaxFallbackValue { return Vector(LogicalTypeId::BLOB); } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format) { - const OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + auto order_by_null_type = nulls_last ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; + const OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); CreateSortKeyHelpers::CreateSortKeyWithValidity(input, extra_state, modifiers, count); input.Flatten(count); extra_state.ToUnifiedFormat(count, format); } }; +template +struct ValueOrNull { + T value; + bool is_valid; + + bool operator==(const ValueOrNull &other) const { + return is_valid == other.is_valid && value == other.value; + } + + bool operator>(const ValueOrNull &other) const { + if (is_valid && other.is_valid) { + return value > other.value; + } + if (!is_valid && !other.is_valid) { + return false; + } + + return is_valid ^ NULLS_LAST; + } +}; + +template +struct MinMaxFixedValueOrNull { + using TYPE = ValueOrNull; + using EXTRA_STATE = bool; + + static TYPE Create(const UnifiedVectorFormat &format, const idx_t idx) { + return TYPE {UnifiedVectorFormat::GetData(format)[idx], format.validity.RowIsValid(idx)}; + } + + static void Assign(Vector &vector, const idx_t idx, const TYPE &value, const bool nulls_last) { + FlatVector::Validity(vector).Set(idx, value.is_valid); + FlatVector::GetData(vector)[idx] = value.value; + } + + static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + return false; + } + + static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + const bool nulls_last) { + input.ToUnifiedFormat(count, format); + } +}; + //------------------------------------------------------------------------------ // MinMaxN Operation (common for both ArgMinMaxN and MinMaxN) //------------------------------------------------------------------------------ @@ -343,7 +417,11 @@ struct MinMaxNOperation { } template - static void Finalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { + static void Finalize(Vector &state_vector, AggregateInputData &input_data, Vector &result, idx_t count, + idx_t offset) { + // We only expect bind data from arg_max, otherwise nulls last is the default + const bool nulls_last = + input_data.bind_data ? 
input_data.bind_data->Cast().nulls_last : true; UnifiedVectorFormat state_format; state_vector.ToUnifiedFormat(count, state_format); @@ -387,7 +465,7 @@ struct MinMaxNOperation { auto heap = state.heap.SortAndGetHeap(); for (idx_t slot = 0; slot < state.heap.Size(); slot++) { - STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot])); + STATE::VAL_TYPE::Assign(child_data, current_offset++, state.heap.GetValue(heap[slot]), nulls_last); } } diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp index 534465ea5..1ac83bc0a 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -184,6 +184,59 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug distinct_dependent(AggregateDistinctDependent::DISTINCT_DEPENDENT) { } + // clang-format off + bool HasBindCallback() const { return bind != nullptr; } + bind_aggregate_function_t GetBindCallback() const { return bind; } + void SetBindCallback(bind_aggregate_function_t callback) { bind = callback; } + + bool HasStateInitCallback() const { return initialize != nullptr; } + aggregate_initialize_t GetStateInitCallback() const { return initialize; } + void SetStateInitCallback(aggregate_initialize_t callback) { initialize = callback; } + + bool HasStateSizeCallback() const { return state_size != nullptr; } + aggregate_size_t GetStateSizeCallback() const { return state_size; } + void SetStateSizeCallback(aggregate_size_t callback) { state_size = callback; } + + bool HasStateDestructorCallback() const { return destructor != nullptr; } + aggregate_destructor_t GetStateDestructorCallback() const { return destructor; } + void SetStateDestructorCallback(aggregate_destructor_t callback) { destructor = callback; } + + bool HasStateUpdateCallback() const { return update != nullptr; } + aggregate_update_t GetStateUpdateCallback() const { return update; } + void SetStateUpdateCallback(aggregate_update_t callback) { update = callback; } + + bool HasStateSimpleUpdateCallback() const { return simple_update != nullptr; } + aggregate_simple_update_t GetStateSimpleUpdateCallback() const { return simple_update; } + void SetStateSimpleUpdateCallback(aggregate_simple_update_t callback) { simple_update = callback; } + + void SetStateCombineCallback(aggregate_combine_t callback) { combine = callback; } + aggregate_combine_t GetStateCombineCallback() const { return combine; } + bool HasStateCombineCallback() const { return combine != nullptr; } + + void SetStateFinalizeCallback(aggregate_finalize_t callback) { finalize = callback; } + aggregate_finalize_t GetStateFinalizeCallback() const { return finalize; } + bool HasStateFinalizeCallback() const { return finalize != nullptr; } + + bool HasWindowCallback() const { return window != nullptr; } + aggregate_window_t GetWindowCallback() const { return window; } + void SetWindowCallback(aggregate_window_t callback) { window = callback; } + + void SetWindowInitCallback(aggregate_wininit_t callback) { window_init = callback; } + aggregate_wininit_t GetWindowInitCallback() const { return window_init; } + bool HasWindowInitCallback() const { return window_init != nullptr; } + + bool HasStatisticsCallback() const { return statistics != nullptr; } + aggregate_statistics_t GetStatisticsCallback() const { return statistics; } + void SetStatisticsCallback(aggregate_statistics_t callback) { statistics = callback; 
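The minmax_n_helpers.hpp changes above thread a nulls_last flag (carried by the new ArgMinMaxFunctionData bind data) into PrepareData and Assign, and the new ValueOrNull wrapper makes NULL entries compare before or after all valid values via the is_valid ^ NULLS_LAST trick. A standalone sketch of that comparison rule, with simplified stand-in names rather than the actual DuckDB types:

#include <algorithm>
#include <cstdio>
#include <vector>

// Simplified stand-in for the ValueOrNull wrapper added in the patch.
template <class T, bool NULLS_LAST>
struct MaybeNull {
	T value;
	bool is_valid;

	bool operator>(const MaybeNull &other) const {
		if (is_valid && other.is_valid) {
			return value > other.value;
		}
		if (!is_valid && !other.is_valid) {
			return false;
		}
		// Exactly one side is NULL: with NULLS_LAST, the NULL side compares as larger.
		return is_valid ^ NULLS_LAST;
	}
};

int main() {
	std::vector<MaybeNull<int, true>> vals {{3, true}, {0, false}, {1, true}};
	// Sort ascending: the NULL entry ends up last because it compares as larger.
	std::sort(vals.begin(), vals.end(), [](const auto &a, const auto &b) { return b > a; });
	for (auto &v : vals) {
		if (v.is_valid) {
			printf("%d ", v.value);
		} else {
			printf("NULL ");
		}
	}
	printf("\n"); // prints: 1 3 NULL
}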
} + + bool HasSerializationCallbacks() const { return serialize != nullptr && deserialize != nullptr; } + void SetSerializeCallback(aggregate_serialize_t callback) { serialize = callback; } + void SetDeserializeCallback(aggregate_deserialize_t callback) { deserialize = callback; } + aggregate_serialize_t GetSerializeCallback() const { return serialize; } + aggregate_deserialize_t GetDeserializeCallback() const { return deserialize; } + // clang-format on + +public: //! The hashed aggregate state sizing function aggregate_size_t state_size; //! The hashed aggregate state initialization function @@ -211,13 +264,29 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug aggregate_serialize_t serialize; aggregate_deserialize_t deserialize; + //! Whether or not the aggregate is order dependent AggregateOrderDependent order_dependent; //! Whether or not the aggregate is affect by distinct modifiers AggregateDistinctDependent distinct_dependent; + + AggregateOrderDependent GetOrderDependent() const { + return order_dependent; + } + void SetOrderDependent(AggregateOrderDependent value) { + order_dependent = value; + } + AggregateDistinctDependent GetDistinctDependent() const { + return distinct_dependent; + } + void SetDistinctDependent(AggregateDistinctDependent value) { + distinct_dependent = value; + } + //! Additional function info, passed to the bind shared_ptr function_info; +public: bool operator==(const AggregateFunction &rhs) const { return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update && combine == rhs.combine && finalize == rhs.finalize && window == rhs.window; diff --git a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp index d87a3a976..107b482b7 100644 --- a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp @@ -170,6 +170,7 @@ struct DefaultCasts { static BoundCastInfo UnionCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo VariantCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo BignumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp index 2e7fbf68e..9aa105bd3 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp @@ -357,6 +357,9 @@ bool ConvertPrimitiveToVariant(ToVariantSourceData &source, ToVariantGlobalResul case LogicalTypeId::CHAR: return ConvertPrimitiveTemplated( source, result, count, selvec, values_index_selvec, empty_payload, is_root); + case LogicalTypeId::GEOMETRY: + return ConvertPrimitiveTemplated( + source, result, count, selvec, values_index_selvec, empty_payload, is_root); case LogicalTypeId::BLOB: return ConvertPrimitiveTemplated( source, result, count, selvec, 
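aggregate_function.hpp above wraps every optional callback in Has/Get/Set accessors instead of having call sites poke the raw function-pointer members (the members themselves stay public for now). A minimal sketch of that guarded-callback style, using a stand-in type rather than DuckDB's:

#include <cstdio>
#include <functional>

// Minimal stand-in showing the guarded-callback accessor style used above.
struct AggregateSketch {
	using combine_t = std::function<int(int, int)>;

	bool HasCombineCallback() const { return static_cast<bool>(combine); }
	combine_t GetCombineCallback() const { return combine; }
	void SetCombineCallback(combine_t cb) { combine = std::move(cb); }

private:
	combine_t combine; // the backing member stays an implementation detail
};

int main() {
	AggregateSketch fun;
	fun.SetCombineCallback([](int a, int b) { return a + b; });
	if (fun.HasCombineCallback()) {
		printf("%d\n", fun.GetCombineCallback()(40, 2)); // prints 42
	}
}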
values_index_selvec, empty_payload, is_root); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp index 5a8b088ae..209598a74 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp @@ -98,7 +98,7 @@ bool ConvertStructToVariant(ToVariantSourceData &source, ToVariantGlobalResultDa } } if (WRITE_DATA) { - //! Now forward the selection to point to the next index in the children.values_index + //! Now move the selection forward to write the value_id for the next struct child, for each row for (idx_t i = 0; i < sel.count; i++) { sel.children_selection[i]++; } diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp index 00edc6459..56e32577f 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/to_variant_fwd.hpp @@ -14,6 +14,8 @@ namespace duckdb { namespace variant { +void InitializeOffsets(DataChunk &offsets, idx_t count); + struct OffsetData { public: static uint32_t *GetKeys(DataChunk &offsets) { @@ -110,7 +112,6 @@ template void WriteVariantMetadata(ToVariantGlobalResultData &result, idx_t result_index, uint32_t *values_offsets, uint32_t blob_offset, optional_ptr value_index_selvec, idx_t i, VariantLogicalType type_id) { - auto &values_offset_data = values_offsets[result_index]; if (WRITE_DATA) { auto &variant = result.variant; diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp index 28d9db96b..482c3dcd5 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp @@ -1,99 +1,251 @@ #pragma once #include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_visitor.hpp" namespace duckdb { namespace variant { -static bool VariantIsTrivialPrimitive(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - case VariantLogicalType::INT16: - case VariantLogicalType::INT32: - case VariantLogicalType::INT64: - case VariantLogicalType::INT128: - case VariantLogicalType::UINT8: - case VariantLogicalType::UINT16: - case VariantLogicalType::UINT32: - case VariantLogicalType::UINT64: - case VariantLogicalType::UINT128: - case VariantLogicalType::FLOAT: - case VariantLogicalType::DOUBLE: - case VariantLogicalType::UUID: - case VariantLogicalType::DATE: - case VariantLogicalType::TIME_MICROS: - case VariantLogicalType::TIME_NANOS: - case VariantLogicalType::TIMESTAMP_SEC: - case VariantLogicalType::TIMESTAMP_MILIS: - case VariantLogicalType::TIMESTAMP_MICROS: - case VariantLogicalType::TIMESTAMP_NANOS: - case VariantLogicalType::TIME_MICROS_TZ: - case VariantLogicalType::TIMESTAMP_MICROS_TZ: - case VariantLogicalType::INTERVAL: - return true; - default: - return false; +namespace { + +struct AnalyzeState { +public: + explicit AnalyzeState(uint32_t &children_offset) : children_offset(children_offset) { } -} -static uint32_t VariantTrivialPrimitiveSize(VariantLogicalType type) { - switch (type) { - case VariantLogicalType::INT8: - return sizeof(int8_t); - case VariantLogicalType::INT16: - return sizeof(int16_t); - case 
VariantLogicalType::INT32: - return sizeof(int32_t); - case VariantLogicalType::INT64: - return sizeof(int64_t); - case VariantLogicalType::INT128: - return sizeof(hugeint_t); - case VariantLogicalType::UINT8: - return sizeof(uint8_t); - case VariantLogicalType::UINT16: - return sizeof(uint16_t); - case VariantLogicalType::UINT32: - return sizeof(uint32_t); - case VariantLogicalType::UINT64: - return sizeof(uint64_t); - case VariantLogicalType::UINT128: - return sizeof(uhugeint_t); - case VariantLogicalType::FLOAT: +public: + uint32_t &children_offset; +}; + +struct WriteState { +public: + WriteState(uint32_t &keys_offset, uint32_t &children_offset, uint32_t &blob_offset, data_ptr_t blob_data, + uint32_t &blob_size) + : keys_offset(keys_offset), children_offset(children_offset), blob_offset(blob_offset), blob_data(blob_data), + blob_size(blob_size) { + } + +public: + inline data_ptr_t GetDestination() { + return blob_data + blob_offset + blob_size; + } + +public: + uint32_t &keys_offset; + uint32_t &children_offset; + uint32_t &blob_offset; + data_ptr_t blob_data; + uint32_t &blob_size; +}; + +struct VariantToVariantSizeAnalyzer { + using result_type = uint32_t; + + static uint32_t VisitNull(AnalyzeState &state) { + return 0; + } + static uint32_t VisitBoolean(bool, AnalyzeState &state) { + return 0; + } + + template + static uint32_t VisitInteger(T, AnalyzeState &state) { + return sizeof(T); + } + + static uint32_t VisitFloat(float, AnalyzeState &state) { return sizeof(float); - case VariantLogicalType::DOUBLE: + } + static uint32_t VisitDouble(double, AnalyzeState &state) { return sizeof(double); - case VariantLogicalType::UUID: + } + static uint32_t VisitUUID(hugeint_t, AnalyzeState &state) { return sizeof(hugeint_t); - case VariantLogicalType::DATE: + } + static uint32_t VisitDate(date_t, AnalyzeState &state) { return sizeof(int32_t); - case VariantLogicalType::TIME_MICROS: + } + static uint32_t VisitInterval(interval_t, AnalyzeState &state) { + return sizeof(interval_t); + } + + static uint32_t VisitTime(dtime_t, AnalyzeState &state) { return sizeof(dtime_t); - case VariantLogicalType::TIME_NANOS: + } + static uint32_t VisitTimeNanos(dtime_ns_t, AnalyzeState &state) { return sizeof(dtime_ns_t); - case VariantLogicalType::TIMESTAMP_SEC: + } + static uint32_t VisitTimeTZ(dtime_tz_t, AnalyzeState &state) { + return sizeof(dtime_tz_t); + } + static uint32_t VisitTimestampSec(timestamp_sec_t, AnalyzeState &state) { return sizeof(timestamp_sec_t); - case VariantLogicalType::TIMESTAMP_MILIS: + } + static uint32_t VisitTimestampMs(timestamp_ms_t, AnalyzeState &state) { return sizeof(timestamp_ms_t); - case VariantLogicalType::TIMESTAMP_MICROS: + } + static uint32_t VisitTimestamp(timestamp_t, AnalyzeState &state) { return sizeof(timestamp_t); - case VariantLogicalType::TIMESTAMP_NANOS: + } + static uint32_t VisitTimestampNanos(timestamp_ns_t, AnalyzeState &state) { return sizeof(timestamp_ns_t); - case VariantLogicalType::TIME_MICROS_TZ: - return sizeof(dtime_tz_t); - case VariantLogicalType::TIMESTAMP_MICROS_TZ: + } + static uint32_t VisitTimestampTZ(timestamp_tz_t, AnalyzeState &state) { return sizeof(timestamp_tz_t); - case VariantLogicalType::INTERVAL: - return sizeof(interval_t); - default: - throw InternalException("VariantLogicalType '%s' is not a trivial primitive", EnumUtil::ToString(type)); } -} + + static uint32_t VisitString(const string_t &str, AnalyzeState &state) { + auto length = static_cast(str.GetSize()); + return GetVarintSize(length) + length; + } + + static uint32_t 
VisitBlob(const string_t &blob, AnalyzeState &state) { + return VisitString(blob, state); + } + static uint32_t VisitBignum(const string_t &bignum, AnalyzeState &state) { + return VisitString(bignum, state); + } + static uint32_t VisitGeometry(const string_t &geom, AnalyzeState &state) { + return VisitString(geom, state); + } + static uint32_t VisitBitstring(const string_t &bits, AnalyzeState &state) { + return VisitString(bits, state); + } + + template + static uint32_t VisitDecimal(T, uint32_t width, uint32_t scale, AnalyzeState &state) { + uint32_t size = GetVarintSize(width) + GetVarintSize(scale); + size += sizeof(T); + return size; + } + + static uint32_t VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + AnalyzeState &state) { + uint32_t size = GetVarintSize(nested_data.child_count); + if (nested_data.child_count) { + size += GetVarintSize(nested_data.children_idx + state.children_offset); + } + return size; + } + + static uint32_t VisitObject(const UnifiedVariantVectorData &variant, idx_t row, + const VariantNestedData &nested_data, AnalyzeState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static uint32_t VisitDefault(VariantLogicalType type_id, const_data_ptr_t, AnalyzeState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +struct VariantToVariantDataWriter { + using result_type = void; + + static void VisitNull(WriteState &state) { + return; + } + static void VisitBoolean(bool, WriteState &state) { + return; + } + + template + static void VisitInteger(T val, WriteState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDouble(double val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitUUID(hugeint_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitDate(date_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitInterval(interval_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTime(dtime_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeNanos(dtime_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimeTZ(dtime_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampSec(timestamp_sec_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampMs(timestamp_ms_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestamp(timestamp_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampNanos(timestamp_ns_t val, WriteState &state) { + VisitInteger(val, state); + } + static void VisitTimestampTZ(timestamp_tz_t val, WriteState &state) { + VisitInteger(val, state); + } + + static void VisitString(const string_t &str, WriteState &state) { + auto length = str.GetSize(); + state.blob_size += VarintEncode(length, state.GetDestination()); + memcpy(state.GetDestination(), str.GetData(), length); + state.blob_size += length; + } + static void VisitBlob(const string_t &blob, WriteState &state) { + return VisitString(blob, state); + } + static void VisitBignum(const string_t &bignum, WriteState &state) { + return VisitString(bignum, state); + } + static void VisitGeometry(const string_t &geom, WriteState &state) { + return 
VisitString(geom, state); + } + static void VisitBitstring(const string_t &bits, WriteState &state) { + return VisitString(bits, state); + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, WriteState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + state.blob_size += VarintEncode(nested_data.child_count, state.GetDestination()); + if (nested_data.child_count) { + //! NOTE: The 'child_index' stored in the OBJECT/ARRAY data could require more bits + //! That's the reason we have to rewrite the data in VARIANT->VARIANT cast + state.blob_size += VarintEncode(nested_data.children_idx + state.children_offset, state.GetDestination()); + } + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + WriteState &state) { + return VisitArray(variant, row, nested_data, state); + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, WriteState &) { + throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(type_id)); + } +}; + +} // namespace template bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalResultData &result_data, idx_t count, optional_ptr selvec, optional_ptr values_index_selvec, const bool is_root) { - auto keys_offset_data = OffsetData::GetKeys(result_data.offsets); auto children_offset_data = OffsetData::GetChildren(result_data.offsets); auto values_offset_data = OffsetData::GetValues(result_data.offsets); @@ -168,99 +320,26 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe } } - auto source_blob_data = const_data_ptr_cast(source.GetData(source_index).GetData()); - - //! Then write all values auto source_values_list_entry = source.GetValuesListEntry(source_index); - for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; - source_value_index++) { - auto source_type_id = source.GetTypeId(source_index, source_value_index); - auto source_byte_offset = source.GetByteOffset(source_index, source_value_index); - - //! NOTE: we have to deserialize these in both passes - //! because to figure out the size of the 'data' that is added by the VARIANT, we have to traverse the - //! 
VARIANT solely because the 'child_index' stored in the OBJECT/ARRAY data could require more bits - WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, - nullptr, 0, source_type_id); - - if (source_type_id == VariantLogicalType::ARRAY || source_type_id == VariantLogicalType::OBJECT) { - auto source_nested_data = VariantUtils::DecodeNestedData(source, source_index, source_value_index); - if (WRITE_DATA) { - VarintEncode(source_nested_data.child_count, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(source_nested_data.child_count); - if (source_nested_data.child_count) { - auto new_child_index = source_nested_data.children_idx + children_offset; - if (WRITE_DATA) { - VarintEncode(new_child_index, blob_data + blob_offset + blob_size); - } - blob_size += GetVarintSize(new_child_index); - } - } else if (source_type_id == VariantLogicalType::VARIANT_NULL || - source_type_id == VariantLogicalType::BOOL_FALSE || - source_type_id == VariantLogicalType::BOOL_TRUE) { - // no-op - } else if (source_type_id == VariantLogicalType::DECIMAL) { - auto decimal_blob_data = source_blob_data + source_byte_offset; - auto width = static_cast(VarintDecode(decimal_blob_data)); - auto width_varint_size = GetVarintSize(width); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - width_varint_size, - width_varint_size); - } - blob_size += width_varint_size; - auto scale = static_cast(VarintDecode(decimal_blob_data)); - auto scale_varint_size = GetVarintSize(scale); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data - scale_varint_size, - scale_varint_size); - } - blob_size += scale_varint_size; - - if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(hugeint_t)); - } - blob_size += sizeof(hugeint_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int64_t)); - } - blob_size += sizeof(int64_t); - } else if (width > DecimalWidth::max) { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int32_t)); - } - blob_size += sizeof(int32_t); - } else { - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, decimal_blob_data, sizeof(int16_t)); - } - blob_size += sizeof(int16_t); - } - } else if (source_type_id == VariantLogicalType::BITSTRING || - source_type_id == VariantLogicalType::BIGNUM || source_type_id == VariantLogicalType::VARCHAR || - source_type_id == VariantLogicalType::BLOB) { - auto str_blob_data = source_blob_data + source_byte_offset; - auto str_length = VarintDecode(str_blob_data); - auto str_length_varint_size = GetVarintSize(str_length); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data - str_length_varint_size, - str_length_varint_size); - } - blob_size += str_length_varint_size; - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, str_blob_data, str_length); - } - blob_size += str_length; - } else if (VariantIsTrivialPrimitive(source_type_id)) { - auto size = VariantTrivialPrimitiveSize(source_type_id); - if (WRITE_DATA) { - memcpy(blob_data + blob_offset + blob_size, source_blob_data + source_byte_offset, size); - } - blob_size += size; - } else { - throw InternalException("Unrecognized VariantLogicalType: %s", EnumUtil::ToString(source_type_id)); + + if (WRITE_DATA) { + WriteState write_state(keys_offset, children_offset, 
blob_offset, blob_data, blob_size); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + auto source_type_id = source.GetTypeId(source_index, source_value_index); + WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, + nullptr, 0, source_type_id); + + VariantVisitor::Visit(source, source_index, source_value_index, + write_state); + } + } else { + AnalyzeState analyze_state(children_offset); + for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; + source_value_index++) { + values_offset_data[result_index]++; + blob_size += VariantVisitor::Visit(source, source_index, + source_value_index, analyze_state); } } diff --git a/src/duckdb/src/include/duckdb/function/compression/compression.hpp b/src/duckdb/src/include/duckdb/function/compression/compression.hpp index 337b0d19c..07647a7af 100644 --- a/src/duckdb/src/include/duckdb/function/compression/compression.hpp +++ b/src/duckdb/src/include/duckdb/function/compression/compression.hpp @@ -16,6 +16,8 @@ namespace duckdb { struct ConstantFun { static CompressionFunction GetFunction(PhysicalType type); static bool TypeIsSupported(const PhysicalType physical_type); + static void FiltersNullValues(const LogicalType &type, const TableFilter &filter, bool &filters_nulls, + bool &filters_valid_values, TableFilterState &filter_state); }; struct UncompressedFun { diff --git a/src/duckdb/src/include/duckdb/function/compression_function.hpp b/src/duckdb/src/include/duckdb/function/compression_function.hpp index 64b1c2a58..c610e2319 100644 --- a/src/duckdb/src/include/duckdb/function/compression_function.hpp +++ b/src/duckdb/src/include/duckdb/function/compression_function.hpp @@ -17,6 +17,7 @@ #include "duckdb/storage/data_pointer.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/storage/block_manager.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/storage_lock.hpp" namespace duckdb { @@ -28,7 +29,6 @@ class SegmentStatistics; class TableFilter; struct TableFilterState; struct ColumnSegmentState; - struct ColumnFetchState; struct ColumnScanState; struct PrefetchState; @@ -109,11 +109,6 @@ struct CompressedSegmentState { return ""; } // LCOV_EXCL_STOP - //! Get the block ids of additional pages created by the segment - virtual vector GetAdditionalBlocks() const { // LCOV_EXCL_START - return vector(); - } // LCOV_EXCL_STOP - template TARGET &Cast() { DynamicCastCheck(this); @@ -174,7 +169,8 @@ typedef void (*compression_compress_finalize_t)(CompressionState &state); // Uncompress / Scan //===--------------------------------------------------------------------===// typedef void (*compression_init_prefetch_t)(ColumnSegment &segment, PrefetchState &prefetch_state); -typedef unique_ptr (*compression_init_segment_scan_t)(ColumnSegment &segment); +typedef unique_ptr (*compression_init_segment_scan_t)(const QueryContext &context, + ColumnSegment &segment); //! 
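The variant_to_variant.hpp rewrite above replaces a long hand-written switch with two visitor policies: VariantToVariantSizeAnalyzer, which only sums the bytes each value will occupy, and VariantToVariantDataWriter, which serializes into a buffer sized by the first pass. The patch's own comment notes both passes are needed because the varint-encoded child_index can widen once children offsets are rebased. A standalone sketch of that analyze-then-write idea with a toy varint; the encoding below is illustrative, not DuckDB's variant format:

#include <cstdint>
#include <cstdio>
#include <vector>

// Toy unsigned LEB128 varint, standing in for GetVarintSize / VarintEncode.
static uint32_t VarintSize(uint64_t value) {
	uint32_t size = 1;
	while (value >= 0x80) {
		value >>= 7;
		size++;
	}
	return size;
}

static uint32_t VarintWrite(uint64_t value, uint8_t *out) {
	uint32_t size = 0;
	while (value >= 0x80) {
		out[size++] = static_cast<uint8_t>(value & 0x7F) | 0x80;
		value >>= 7;
	}
	out[size++] = static_cast<uint8_t>(value);
	return size;
}

int main() {
	// Rebasing child offsets (here: +300) can change the varint width, which is why
	// an analyze pass runs before the write pass instead of reusing the old sizes.
	std::vector<uint32_t> child_offsets {1, 100, 200};
	const uint32_t rebase = 300;

	// Pass 1: analyze, only summing sizes.
	uint32_t total = 0;
	for (auto offset : child_offsets) {
		total += VarintSize(offset + rebase);
	}

	// Pass 2: write into a buffer allocated from the analyzed size.
	std::vector<uint8_t> blob(total);
	uint32_t written = 0;
	for (auto offset : child_offsets) {
		written += VarintWrite(offset + rebase, blob.data() + written);
	}
	printf("analyzed %u bytes, wrote %u bytes\n", total, written); // both print 6
}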
Function prototype used for reading an entire vector (STANDARD_VECTOR_SIZE) typedef void (*compression_scan_vector_t)(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, @@ -205,7 +201,7 @@ typedef unique_ptr (*compression_init_append_t)(ColumnSe typedef idx_t (*compression_append_t)(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count); typedef idx_t (*compression_finalize_append_t)(ColumnSegment &segment, SegmentStatistics &stats); -typedef void (*compression_revert_append_t)(ColumnSegment &segment, idx_t start_row); +typedef void (*compression_revert_append_t)(ColumnSegment &segment, idx_t new_count); //===--------------------------------------------------------------------===// // Serialization (optional) @@ -215,13 +211,14 @@ typedef unique_ptr (*compression_serialize_state_t)(ColumnSe //! Function prototype for deserializing the segment state typedef unique_ptr (*compression_deserialize_state_t)(Deserializer &deserializer); //! Function prototype for cleaning up the segment state when the column data is dropped -typedef void (*compression_cleanup_state_t)(ColumnSegment &segment); +typedef void (*compression_visit_block_ids_t)(const ColumnSegment &segment, BlockIdVisitor &visitor); //===--------------------------------------------------------------------===// // GetSegmentInfo (optional) //===--------------------------------------------------------------------===// //! Function prototype for retrieving segment information straight from the column segment -typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(ColumnSegment &segment); +typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(QueryContext context, + ColumnSegment &segment); enum class CompressionValidity : uint8_t { REQUIRES_VALIDITY, NO_VALIDITY_REQUIRED }; @@ -239,7 +236,7 @@ class CompressionFunction { compression_revert_append_t revert_append = nullptr, compression_serialize_state_t serialize_state = nullptr, compression_deserialize_state_t deserialize_state = nullptr, - compression_cleanup_state_t cleanup_state = nullptr, + compression_visit_block_ids_t visit_block_ids = nullptr, compression_init_prefetch_t init_prefetch = nullptr, compression_select_t select = nullptr, compression_filter_t filter = nullptr) : type(type), data_type(data_type), init_analyze(init_analyze), analyze(analyze), final_analyze(final_analyze), @@ -247,7 +244,7 @@ class CompressionFunction { init_prefetch(init_prefetch), init_scan(init_scan), scan_vector(scan_vector), scan_partial(scan_partial), select(select), filter(filter), fetch_row(fetch_row), skip(skip), init_segment(init_segment), init_append(init_append), append(append), finalize_append(finalize_append), revert_append(revert_append), - serialize_state(serialize_state), deserialize_state(deserialize_state), cleanup_state(cleanup_state) { + serialize_state(serialize_state), deserialize_state(deserialize_state), visit_block_ids(visit_block_ids) { } //! Compression type @@ -317,8 +314,8 @@ class CompressionFunction { compression_serialize_state_t serialize_state; //! Deserialize the segment state to the metadata (optional) compression_deserialize_state_t deserialize_state; - //! Cleanup the segment state (optional) - compression_cleanup_state_t cleanup_state; + //! Iterate over any extra block ids used by the compression algorithm (optional) + compression_visit_block_ids_t visit_block_ids; // Get Segment Info //! 
This is only necessary if you want to convey more information about the segment in the 'pragma_storage_info' diff --git a/src/duckdb/src/include/duckdb/function/copy_function.hpp b/src/duckdb/src/include/duckdb/function/copy_function.hpp index cfd379c0a..1b035ebd1 100644 --- a/src/duckdb/src/include/duckdb/function/copy_function.hpp +++ b/src/duckdb/src/include/duckdb/function/copy_function.hpp @@ -23,6 +23,21 @@ class ColumnDataCollection; class ExecutionContext; class PhysicalOperatorLogger; +struct CopyFunctionInfo { + virtual ~CopyFunctionInfo() = default; + + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } +}; + struct LocalFunctionData { virtual ~LocalFunctionData() = default; @@ -69,11 +84,12 @@ struct PreparedBatchData { }; struct CopyFunctionBindInput { - explicit CopyFunctionBindInput(const CopyInfo &info_p) : info(info_p) { + explicit CopyFunctionBindInput(const CopyInfo &info_p, shared_ptr function_info = nullptr) + : info(info_p), function_info(std::move(function_info)) { } const CopyInfo &info; - + shared_ptr function_info; string file_extension; }; @@ -199,6 +215,9 @@ class CopyFunction : public Function { // NOLINT: work-around bug in clang-tidy TableFunction copy_from_function; string extension; + + //! Additional function info, passed to the bind + shared_ptr function_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp index 0ce926c1f..b2b5c08c3 100644 --- a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp +++ b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp @@ -48,7 +48,7 @@ struct OrderModifiers { struct CreateSortKeyHelpers { static void CreateSortKey(DataChunk &input, const vector &modifiers, Vector &result); static void CreateSortKey(Vector &input, idx_t input_count, OrderModifiers modifiers, Vector &result); - static void DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, OrderModifiers modifiers); + static idx_t DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, OrderModifiers modifiers); static void DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, const vector &modifiers); static void CreateSortKeyWithValidity(Vector &input, Vector &result, const OrderModifiers &modifiers, diff --git a/src/duckdb/src/include/duckdb/function/function.hpp b/src/duckdb/src/include/duckdb/function/function.hpp index 587216421..bd9960319 100644 --- a/src/duckdb/src/include/duckdb/function/function.hpp +++ b/src/duckdb/src/include/duckdb/function/function.hpp @@ -175,6 +175,55 @@ class BaseScalarFunction : public SimpleFunction { FunctionErrors errors = FunctionErrors::CANNOT_ERROR); DUCKDB_API ~BaseScalarFunction() override; +public: + void SetReturnType(LogicalType return_type_p) { + return_type = std::move(return_type_p); + } + const LogicalType &GetReturnType() const { + return return_type; + } + LogicalType &GetReturnType() { + return return_type; + } + + FunctionStability GetStability() const { + return stability; + } + void SetStability(FunctionStability stability_p) { + stability = stability_p; + } + + FunctionNullHandling GetNullHandling() const { + return null_handling; + } + void SetNullHandling(FunctionNullHandling null_handling_p) { + null_handling = null_handling_p; + } + + FunctionErrors GetErrorMode() const { + return errors; + } + void 
SetErrorMode(FunctionErrors errors_p) { + errors = errors_p; + } + + //! Set this functions error-mode as fallible (can throw runtime errors) + void SetFallible() { + errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; + } + //! Set this functions stability as volatile (can not be cached per row) + void SetVolatile() { + stability = FunctionStability::VOLATILE; + } + + void SetCollationHandling(FunctionCollationHandling collation_handling_p) { + collation_handling = collation_handling_p; + } + FunctionCollationHandling GetCollationHandling() const { + return collation_handling; + } + +public: //! Return type of the function LogicalType return_type; //! The stability of the function (see FunctionStability enum for more info) diff --git a/src/duckdb/src/include/duckdb/function/function_binder.hpp b/src/duckdb/src/include/duckdb/function/function_binder.hpp index 6eba740ab..6b43a0777 100644 --- a/src/duckdb/src/include/duckdb/function/function_binder.hpp +++ b/src/duckdb/src/include/duckdb/function/function_binder.hpp @@ -70,7 +70,8 @@ class FunctionBinder { AggregateType aggr_type = AggregateType::NON_DISTINCT); DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, - const vector> &groups); + const vector> &groups, + optional_ptr> grouping_sets); DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr); //! Cast a set of expressions to the arguments of this function diff --git a/src/duckdb/src/include/duckdb/function/function_serialization.hpp b/src/duckdb/src/include/duckdb/function/function_serialization.hpp index d7d3480ac..58fc9ae50 100644 --- a/src/duckdb/src/include/duckdb/function/function_serialization.hpp +++ b/src/duckdb/src/include/duckdb/function/function_serialization.hpp @@ -29,12 +29,13 @@ class FunctionSerializer { // the fields are present, they will be used. 
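function_serialization.hpp below now asks HasSerializationCallbacks(), which requires both serialize and deserialize to be set, before emitting custom function data; writing bind data that could never be read back would be worse than writing none. A small standalone sketch of that pairing check, with stand-in types rather than DuckDB's serializer plumbing:

#include <cstdio>
#include <functional>
#include <string>

// Stand-in for a function object that may carry custom (de)serialization hooks.
struct FunctionSketch {
	std::function<std::string()> serialize;
	std::function<void(const std::string &)> deserialize;

	// Mirrors HasSerializationCallbacks(): both halves must be present,
	// otherwise data written by serialize could never be read back.
	bool HasSerializationCallbacks() const {
		return serialize != nullptr && deserialize != nullptr;
	}
};

int main() {
	FunctionSketch fun;
	fun.serialize = [] { return std::string("bind data"); };
	// deserialize deliberately left unset: the guard refuses to write custom data.
	printf("has_serialize: %d\n", fun.HasSerializationCallbacks() ? 1 : 0); // prints 0
}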
serializer.WritePropertyWithDefault(505, "catalog_name", function.catalog_name, ""); serializer.WritePropertyWithDefault(506, "schema_name", function.schema_name, ""); - bool has_serialize = function.serialize; + + bool has_serialize = function.HasSerializationCallbacks(); serializer.WriteProperty(503, "has_serialize", has_serialize); if (has_serialize) { serializer.WriteObject(504, "function_data", - [&](Serializer &obj) { function.serialize(obj, bind_info, function); }); - D_ASSERT(function.deserialize); + [&](Serializer &obj) { function.GetSerializeCallback()(obj, bind_info, function); }); + D_ASSERT(function.GetDeserializeCallback()); } } @@ -94,13 +95,13 @@ class FunctionSerializer { template static unique_ptr FunctionDeserialize(Deserializer &deserializer, FUNC &function) { - if (!function.deserialize) { + if (!function.HasSerializationCallbacks()) { throw SerializationException("Function requires deserialization but no deserialization function for %s", function.name); } unique_ptr result; deserializer.ReadObject(504, "function_data", - [&](Deserializer &obj) { result = function.deserialize(obj, function); }); + [&](Deserializer &obj) { result = function.GetDeserializeCallback()(obj, function); }); return result; } @@ -156,15 +157,14 @@ class FunctionSerializer { bind_data = FunctionDeserialize(deserializer, function); deserializer.Unset(); } else { - FunctionBinder binder(context); // Resolve templates binder.ResolveTemplateTypes(function, children); - if (function.bind) { + if (function.HasBindCallback()) { try { - bind_data = function.bind(context, function, children); + bind_data = function.GetBindCallback()(context, function, children); } catch (std::exception &ex) { ErrorData error(ex); throw SerializationException("Error during bind of function in deserialization: %s", @@ -178,8 +178,8 @@ class FunctionSerializer { binder.CastToFunctionArguments(function, children); } - if (TypeRequiresAssignment(function.return_type)) { - function.return_type = std::move(return_type); + if (TypeRequiresAssignment(function.GetReturnType())) { + function.SetReturnType(std::move(return_type)); } return make_pair(std::move(function), std::move(bind_data)); } diff --git a/src/duckdb/src/include/duckdb/function/partition_stats.hpp b/src/duckdb/src/include/duckdb/function/partition_stats.hpp index d703ddbe1..737bb3064 100644 --- a/src/duckdb/src/include/duckdb/function/partition_stats.hpp +++ b/src/duckdb/src/include/duckdb/function/partition_stats.hpp @@ -9,6 +9,8 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/optional_idx.hpp" namespace duckdb { @@ -22,15 +24,23 @@ enum class TablePartitionInfo : uint8_t { enum class CountType { COUNT_EXACT, COUNT_APPROXIMATE }; +struct PartitionRowGroup { + virtual ~PartitionRowGroup() = default; + virtual unique_ptr GetColumnStatistics(column_t column_id) = 0; + virtual bool MinMaxIsExact(const BaseStatistics &stats) = 0; +}; + struct PartitionStatistics { PartitionStatistics(); //! The row id start - idx_t row_start; + optional_idx row_start; //! The amount of rows in the partition idx_t count; //! Whether or not the count is exact or approximate CountType count_type; + //! 
Optional accessor for row group statistics + shared_ptr partition_row_group; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp new file mode 100644 index 000000000..7a15ba00c --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// function/scalar/geometry_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct StGeomfromwkbFun { + static constexpr const char *Name = "st_geomfromwkb"; + static constexpr const char *Parameters = "wkb"; + static constexpr const char *Description = "Creates a geometry from Well-Known Binary (WKB) representation"; + static constexpr const char *Example = "ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000')"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAswkbFun { + static constexpr const char *Name = "st_aswkb"; + static constexpr const char *Parameters = "geom"; + static constexpr const char *Description = "Returns the Well-Known Binary (WKB) representation of the geometry"; + static constexpr const char *Example = "st_aswkb(ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000000'))"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAsbinaryFun { + using ALIAS = StAswkbFun; + + static constexpr const char *Name = "st_asbinary"; +}; + +struct StAstextFun { + static constexpr const char *Name = "st_astext"; + static constexpr const char *Parameters = "geom"; + static constexpr const char *Description = "Returns the Well-Known Text (WKT) representation of the geometry"; + static constexpr const char *Example = "ST_AsText(ST_GeomFromWKB(X'01010000000000000000000000000000000000000000000000'))"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StAswktFun { + using ALIAS = StAstextFun; + + static constexpr const char *Name = "st_aswkt"; +}; + +struct StIntersectsExtentFun { + static constexpr const char *Name = "st_intersects_extent"; + static constexpr const char *Parameters = "geom1,geom2"; + static constexpr const char *Description = "Returns true if the geometries bounding boxes intersect"; + static constexpr const char *Example = "'POINT(5 5)'::GEOMETRY && 'LINESTRING(0 0, 10 20)'::GEOMETRY;"; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct StIntersectsExtentFunAlias { + using ALIAS = StIntersectsExtentFun; + + static constexpr const char *Name = "&&"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/list_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/list_functions.hpp index fca0ca690..3211731c6 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/list_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/list_functions.hpp @@ -119,6 +119,22 @@ struct ArrayZipFun { static 
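A little above, partition_stats.hpp changes PartitionStatistics::row_start from a plain idx_t to optional_idx and adds an optional PartitionRowGroup accessor, so a partition's starting row may now legitimately be unknown. A standalone sketch of what that forces on callers; the OptionalIdx type here is a stand-in, not DuckDB's optional_idx:

#include <cstdint>
#include <cstdio>

// Tiny stand-in for an optional row id.
struct OptionalIdx {
	static constexpr uint64_t INVALID = ~uint64_t(0);
	uint64_t index = INVALID;

	bool IsValid() const { return index != INVALID; }
	uint64_t GetIndex() const { return index; }
};

struct PartitionStatsSketch {
	OptionalIdx row_start; // may be unknown for some partitions
	uint64_t count = 0;
};

int main() {
	PartitionStatsSketch stats;
	stats.count = 1000;
	// Callers now have to check validity instead of trusting a raw idx_t.
	if (stats.row_start.IsValid()) {
		printf("partition starts at row %llu\n", (unsigned long long)stats.row_start.GetIndex());
	} else {
		printf("row start unknown, count=%llu\n", (unsigned long long)stats.count);
	}
}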
constexpr const char *Name = "array_zip"; }; +struct ListIntersectFun { + static constexpr const char *Name = "list_intersect"; + static constexpr const char *Parameters = "list1,list2"; + static constexpr const char *Description = "Returns a list containing the distinct elements that are present in both `list1` and `list2`."; + static constexpr const char *Example = "list_intersect([1, 2, 3], [2, 3, 4])"; + static constexpr const char *Categories = "list"; + + static ScalarFunction GetFunction(); +}; + +struct ArrayIntersectFun { + using ALIAS = ListIntersectFun; + + static constexpr const char *Name = "array_intersect"; +}; + struct ListExtractFun { static constexpr const char *Name = "list_extract"; static constexpr const char *Parameters = "list,index"; diff --git a/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp index 5ac80ab08..2ad1e694b 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/regexp.hpp @@ -20,6 +20,11 @@ namespace regexp_util { bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string); void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace = nullptr); void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace = nullptr); +void ParseGroupNameList(ClientContext &context, const string &function_name, Expression &group_expr, + const string &pattern_string, RE2::Options &options, bool require_constant_pattern, + vector &out_names, child_list_t &out_struct_children); + +idx_t AdvanceOneUTF8Basic(const duckdb_re2::StringPiece &input, idx_t base); inline duckdb_re2::StringPiece CreateStringPiece(const string_t &input) { return duckdb_re2::StringPiece(input.GetData(), input.GetSize()); @@ -53,6 +58,33 @@ struct RegexpBaseBindData : public FunctionData { bool Equals(const FunctionData &other_p) const override; }; +struct RegexpExtractAllStructBindData : public RegexpBaseBindData { + RegexpExtractAllStructBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, + vector group_names) + : RegexpBaseBindData(options, std::move(constant_string), constant_pattern), + group_names(std::move(group_names)) { + } + + vector group_names; // order preserved + + unique_ptr Copy() const override { + return make_uniq(options, constant_string, constant_pattern, group_names); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return RegexpBaseBindData::Equals(other) && group_names == other.group_names; + } +}; + +struct RegexpExtractAllStruct { + static void Execute(DataChunk &args, ExpressionState &state, Vector &result); + static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); + static unique_ptr InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data); +}; + struct RegexpMatchesBindData : public RegexpBaseBindData { RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern); RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string, bool constant_pattern, diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp index 6408639ec..c318a9236 100644 --- 
a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp @@ -25,11 +25,21 @@ struct VariantExtractFun { static ScalarFunctionSet GetFunctions(); }; +struct VariantNormalizeFun { + static constexpr const char *Name = "variant_normalize"; + static constexpr const char *Parameters = "input_variant"; + static constexpr const char *Description = "Normalizes the `input_variant` to a canonical representation."; + static constexpr const char *Example = "variant_normalize({'b': [1,2,3], 'a': 42})::VARIANT)"; + static constexpr const char *Categories = "variant"; + + static ScalarFunction GetFunction(); +}; + struct VariantTypeofFun { static constexpr const char *Name = "variant_typeof"; static constexpr const char *Parameters = "input_variant"; static constexpr const char *Description = "Returns the internal type of the `input_variant`."; - static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3])::VARIANT)"; + static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3]})::VARIANT)"; static constexpr const char *Categories = "variant"; static ScalarFunction GetFunction(); diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index f0c4cb82b..1c20b19c0 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -66,20 +66,25 @@ struct VariantUtils { uint32_t value_index); DUCKDB_API static VariantNestedData DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index); + DUCKDB_API static string_t DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t value_index); DUCKDB_API static vector GetObjectKeys(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data); - DUCKDB_API static VariantChildDataCollectionResult FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, - optional_idx row, SelectionVector &res, - VariantNestedData *nested_data, idx_t count); + DUCKDB_API static void FindChildValues(const UnifiedVariantVectorData &variant, + const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count); DUCKDB_API static VariantNestedDataCollectionResult CollectNestedData(const UnifiedVariantVectorData &variant, VariantLogicalType expected_type, const SelectionVector &sel, idx_t count, optional_idx row, idx_t offset, VariantNestedData *child_data, ValidityMask &validity); DUCKDB_API static vector ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, idx_t count, optional_idx row); - DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, idx_t values_idx); + DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t values_idx); DUCKDB_API static bool Verify(Vector &variant, const SelectionVector &sel_p, idx_t count); + DUCKDB_API static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, + SelectionVector &sel, idx_t sel_size); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar_function.hpp b/src/duckdb/src/include/duckdb/function/scalar_function.hpp index e4c4cdfd2..283b9c140 100644 --- 
a/src/duckdb/src/include/duckdb/function/scalar_function.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar_function.hpp @@ -142,6 +142,63 @@ class ScalarFunction : public BaseScalarFunction { // NOLINT: work-around bug in FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, bind_lambda_function_t bind_lambda = nullptr); + // clang-format off + // Keep these on one-line for readability + bool HasFunctionCallback() const { return function != nullptr; } + scalar_function_t GetFunctionCallback() const { return function; } + void SetFunctionCallback(scalar_function_t callback) { function = std::move(callback); } + + bool HasBindCallback() const { return bind != nullptr; }; + bind_scalar_function_t GetBindCallback() const { return bind; }; + void SetBindCallback(bind_scalar_function_t callback) { bind = callback; } + + bool HasBindExtendedCallback() const { return bind_extended != nullptr; } + bind_scalar_function_extended_t GetBindExtendedCallback() const { return bind_extended; } + void SetBindExtendedCallback(bind_scalar_function_extended_t callback) { bind_extended = callback; } + + bool HasBindLambdaCallback() const { return bind_lambda != nullptr; } + bind_lambda_function_t GetBindLambdaCallback() const { return bind_lambda; } + void SetBindLambdaCallback(bind_lambda_function_t callback) { bind_lambda = callback; } + + bool HasBindExpressionCallback() const { return bind_expression != nullptr; } + function_bind_expression_t GetBindExpressionCallback() const { return bind_expression; } + void SetBindExpressionCallback(function_bind_expression_t callback) { bind_expression = callback; } + + bool HasInitStateCallback() const { return init_local_state != nullptr; } + init_local_state_t GetInitStateCallback() const { return init_local_state; } + void SetInitStateCallback(init_local_state_t callback) { init_local_state = callback; } + + bool HasStatisticsCallback() const { return statistics != nullptr; } + function_statistics_t GetStatisticsCallback() const { return statistics; } + void SetStatisticsCallback(function_statistics_t callback) { statistics = callback; } + + bool HasModifiedDatabasesCallback() const { return get_modified_databases != nullptr; } + get_modified_databases_t GetModifiedDatabasesCallback() const { return get_modified_databases; } + void SetModifiedDatabasesCallback(get_modified_databases_t callback) { get_modified_databases = callback; } + + bool HasSerializationCallbacks() const { return serialize != nullptr && deserialize != nullptr; } + void SetSerializeCallback(function_serialize_t callback) { serialize = callback; } + void SetDeserializeCallback(function_deserialize_t callback) { deserialize = callback; } + function_serialize_t GetSerializeCallback() const { return serialize; } + function_deserialize_t GetDeserializeCallback() const { return deserialize; } + // clang-format on + + bool HasExtraFunctionInfo() const { + return function_info != nullptr; + } + ScalarFunctionInfo &GetExtraFunctionInfo() const { + D_ASSERT(function_info.get()); + return *function_info; + } + void SetExtraFunctionInfo(shared_ptr info) { + function_info = std::move(info); + } + template + void SetExtraFunctionInfo(ARGS &&... args) { + function_info = make_shared_ptr(std::forward(args)...); + } + +public: //! The main scalar function to execute scalar_function_t function; //! The bind function (if any) @@ -164,6 +221,7 @@ class ScalarFunction : public BaseScalarFunction { // NOLINT: work-around bug in //! 
Additional function info, passed to the bind shared_ptr function_info; +public: DUCKDB_API bool operator==(const ScalarFunction &rhs) const; DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; diff --git a/src/duckdb/src/include/duckdb/function/table/arrow.hpp b/src/duckdb/src/include/duckdb/function/table/arrow.hpp index a596c86e9..c758f2ed5 100644 --- a/src/duckdb/src/include/duckdb/function/table/arrow.hpp +++ b/src/duckdb/src/include/duckdb/function/table/arrow.hpp @@ -212,7 +212,7 @@ struct ArrowTableFunction { vector &return_types, vector &names); //! Actual conversion from Arrow to DuckDB static void ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, - DataChunk &output, idx_t start, bool arrow_scan_is_projected = true, + DataChunk &output, bool arrow_scan_is_projected = true, idx_t rowid_column_index = COLUMN_IDENTIFIER_ROW_ID); //! Get next scan state diff --git a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp index 02d38ec76..e553a6d41 100644 --- a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp @@ -23,8 +23,8 @@ class DirectFileReader : public BaseFileReader { bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; - void Scan(ClientContext &context, GlobalTableFunctionState &global_state, LocalTableFunctionState &local_state, - DataChunk &chunk) override; + AsyncResult Scan(ClientContext &context, GlobalTableFunctionState &global_state, + LocalTableFunctionState &local_state, DataChunk &chunk) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) override; string GetReaderType() const override { diff --git a/src/duckdb/src/include/duckdb/function/table/read_file.hpp b/src/duckdb/src/include/duckdb/function/table/read_file.hpp index 966fea5ef..e0d2a51c2 100644 --- a/src/duckdb/src/include/duckdb/function/table/read_file.hpp +++ b/src/duckdb/src/include/duckdb/function/table/read_file.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/multi_file/multi_file_function.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" #include "utf8proc_wrapper.hpp" namespace duckdb { @@ -29,55 +30,8 @@ struct ReadFileGlobalState : public GlobalTableFunctionState { shared_ptr file_list; vector column_ids; bool requires_file_open = false; -}; - -struct ReadBlobOperation { - static constexpr const char *NAME = "read_blob"; - static constexpr const char *FILE_TYPE = "blob"; - - static inline LogicalType TYPE() { - return LogicalType::BLOB; - } -}; - -struct ReadTextOperation { - static constexpr const char *NAME = "read_text"; - static constexpr const char *FILE_TYPE = "text"; - - static inline LogicalType TYPE() { - return LogicalType::VARCHAR; - } -}; -template -struct DirectMultiFileInfo : MultiFileReaderInterface { - static unique_ptr CreateInterface(ClientContext &context); - unique_ptr InitializeOptions(ClientContext &context, - optional_ptr info) override; - bool ParseCopyOption(ClientContext &context, const string &key, const vector &values, - BaseFileReaderOptions &options, vector &expected_names, - vector &expected_types) override; - bool ParseOption(ClientContext &context, const string &key, const Value &val, MultiFileOptions &file_options, - BaseFileReaderOptions &options) override; - unique_ptr InitializeBindData(MultiFileBindData 
&multi_file_data, - unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, - MultiFileBindData &bind_data) override; - optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, - FileExpandResult expand_result) override; - unique_ptr InitializeGlobalState(ClientContext &context, MultiFileBindData &bind_data, - MultiFileGlobalState &global_state) override; - unique_ptr InitializeLocalState(ExecutionContext &, GlobalTableFunctionState &) override; - shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, - BaseUnionData &union_data, const MultiFileBindData &bind_data_p) override; - shared_ptr CreateReader(ClientContext &context, GlobalTableFunctionState &gstate, - const OpenFileInfo &file, idx_t file_idx, - const MultiFileBindData &bind_data) override; - shared_ptr CreateReader(ClientContext &context, const OpenFileInfo &file, - BaseFileReaderOptions &options, - const MultiFileOptions &file_options) override; - unique_ptr GetCardinality(const MultiFileBindData &bind_data, idx_t file_count) override; - FileGlobInput GetGlobInput() override; + unique_ptr stream; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp index e325b2f46..49c5e794c 100644 --- a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -47,6 +47,10 @@ struct DuckDBSchemasFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBConnectionCountFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBApproxDatabaseCountFun { static void RegisterFunction(BuiltinFunctions &set); }; diff --git a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp index df4c829da..22407fff5 100644 --- a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp +++ b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp @@ -28,6 +28,8 @@ struct TableScanBindData : public TableFunctionData { bool is_index_scan; //! Whether or not the table scan is for index creation. bool is_create_index; + //! 
In what order to scan the row groups + unique_ptr order_options; public: bool Equals(const FunctionData &other_p) const override { diff --git a/src/duckdb/src/include/duckdb/function/table_function.hpp b/src/duckdb/src/include/duckdb/function/table_function.hpp index f6c9cc55e..dc98f3732 100644 --- a/src/duckdb/src/include/duckdb/function/table_function.hpp +++ b/src/duckdb/src/include/duckdb/function/table_function.hpp @@ -14,10 +14,13 @@ #include "duckdb/function/function.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/storage/statistics/node_statistics.hpp" +#include "duckdb/storage/table/row_group_reorderer.hpp" #include "duckdb/common/column_index.hpp" #include "duckdb/common/table_column.hpp" +#include "duckdb/parallel/async_result.hpp" #include "duckdb/function/partition_stats.hpp" #include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/enums/order_preservation_type.hpp" #include @@ -34,6 +37,9 @@ class SampleOptions; struct MultiFileReader; struct OperatorPartitionData; struct OperatorPartitionInfo; +enum class OrderByColumnType; +enum class RowGroupOrderType; +enum class OrderByStatistics; struct TableFunctionInfo { DUCKDB_API virtual ~TableFunctionInfo(); @@ -158,13 +164,15 @@ struct TableFunctionInput { TableFunctionInput(optional_ptr bind_data_p, optional_ptr local_state_p, optional_ptr global_state_p) - : bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p) { + : bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p), async_result() { } public: optional_ptr bind_data; optional_ptr local_state; optional_ptr global_state; + AsyncResult async_result {}; + AsyncResultsExecutionMode results_execution_mode {AsyncResultsExecutionMode::SYNCHRONOUS}; }; struct TableFunctionPartitionInput { @@ -324,19 +332,53 @@ typedef virtual_column_map_t (*table_function_get_virtual_columns_t)(ClientConte typedef vector (*table_function_get_row_id_columns)(ClientContext &context, optional_ptr bind_data); +typedef void (*table_function_set_scan_order)(unique_ptr order_options, + optional_ptr bind_data); + //! 
When to call init_global to initialize the table function enum class TableFunctionInitialization { INITIALIZE_ON_EXECUTE, INITIALIZE_ON_SCHEDULE }; class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-around bug in clang-tidy public: + DUCKDB_API TableFunction(); + // Overloads taking table_function_t DUCKDB_API - TableFunction(string name, vector arguments, table_function_t function, + TableFunction(string name, const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); DUCKDB_API TableFunction(const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); - DUCKDB_API TableFunction(); + // Overloads taking std::nullptr + DUCKDB_API + TableFunction(string name, const vector &arguments, std::nullptr_t function, + table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, + table_function_init_local_t init_local = nullptr); + DUCKDB_API + TableFunction(const vector &arguments, std::nullptr_t function, table_function_bind_t bind = nullptr, + table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); + + bool HasBindCallback() const { + return bind != nullptr; + } + table_function_bind_t GetBindCallback() const { + return bind; + } + bool HasSerializationCallbacks() const { + return serialize != nullptr && deserialize != nullptr; + } + void SetSerializeCallback(table_function_serialize_t callback) { + serialize = callback; + } + void SetDeserializeCallback(table_function_deserialize_t callback) { + deserialize = callback; + } + table_function_serialize_t GetSerializeCallback() const { + return serialize; + } + table_function_deserialize_t GetDeserializeCallback() const { + return deserialize; + } //! Bind function //! This function is used for determining the return type of a table producing function and returning bind data @@ -404,6 +446,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou table_function_get_virtual_columns_t get_virtual_columns; //! (Optional) returns a list of row id columns table_function_get_row_id_columns get_row_id_columns; + //! (Optional) sets the order to scan the row groups in + table_function_set_scan_order set_scan_order; table_function_serialize_t serialize; table_function_deserialize_t deserialize; @@ -425,6 +469,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou bool late_materialization; //! Additional function info, passed to the bind shared_ptr function_info; + //! The order preservation type of the table function + OrderPreservationType order_preservation_type = OrderPreservationType::INSERTION_ORDER; //! When to call init_global //! 
By default init_global is called when the pipeline is ready for execution @@ -432,6 +478,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou TableFunctionInitialization global_initialization = TableFunctionInitialization::INITIALIZE_ON_EXECUTE; DUCKDB_API bool Equal(const TableFunction &rhs) const; + DUCKDB_API bool operator==(const TableFunction &rhs) const; + DUCKDB_API bool operator!=(const TableFunction &rhs) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/udf_function.hpp b/src/duckdb/src/include/duckdb/function/udf_function.hpp index 571a49af4..3b23445e6 100644 --- a/src/duckdb/src/include/duckdb/function/udf_function.hpp +++ b/src/duckdb/src/include/duckdb/function/udf_function.hpp @@ -123,10 +123,9 @@ struct UDFWrapper { aggregate_combine_t combine, aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr) { - AggregateFunction aggr_function(name, arguments, return_type, state_size, initialize, update, combine, finalize, simple_update, bind, destructor); - aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + aggr_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return aggr_function; } diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp new file mode 100644 index 000000000..28f23d71e --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/serializer/varint.hpp" + +namespace duckdb { + +struct VariantNormalizerState { +public: + VariantNormalizerState(idx_t result_row, VariantVectorData &source, OrderedOwningStringMap &dictionary, + SelectionVector &keys_selvec); + +public: + data_ptr_t GetDestination(); + uint32_t GetOrCreateIndex(const string_t &key); + +public: + uint32_t keys_size = 0; + uint32_t children_size = 0; + uint32_t values_size = 0; + uint32_t blob_size = 0; + + VariantVectorData &source; + OrderedOwningStringMap &dictionary; + SelectionVector &keys_selvec; + + uint64_t keys_offset; + uint64_t children_offset; + ValidityMask &keys_index_validity; + + data_ptr_t blob_data; + uint8_t *type_ids; + uint32_t *byte_offsets; + uint32_t *values_indexes; + uint32_t *keys_indexes; +}; + +struct VariantNormalizer { + using result_type = void; + + static void VisitNull(VariantNormalizerState &state); + static void VisitBoolean(bool val, VariantNormalizerState &state); + static void VisitMetadata(VariantLogicalType type_id, VariantNormalizerState &state); + + template + static void VisitInteger(T val, VariantNormalizerState &state) { + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + static void VisitFloat(float val, VariantNormalizerState &state); + static void VisitDouble(double val, VariantNormalizerState &state); + static void VisitUUID(hugeint_t val, VariantNormalizerState &state); + static void VisitDate(date_t val, VariantNormalizerState &state); + static void VisitInterval(interval_t val, VariantNormalizerState &state); + static void VisitTime(dtime_t val, VariantNormalizerState &state); + static void VisitTimeNanos(dtime_ns_t val, VariantNormalizerState &state); + static void VisitTimeTZ(dtime_tz_t val, VariantNormalizerState &state); + static void VisitTimestampSec(timestamp_sec_t val, 
VariantNormalizerState &state); + static void VisitTimestampMs(timestamp_ms_t val, VariantNormalizerState &state); + static void VisitTimestamp(timestamp_t val, VariantNormalizerState &state); + static void VisitTimestampNanos(timestamp_ns_t val, VariantNormalizerState &state); + static void VisitTimestampTZ(timestamp_tz_t val, VariantNormalizerState &state); + + static void VisitString(const string_t &str, VariantNormalizerState &state); + static void VisitBlob(const string_t &blob, VariantNormalizerState &state); + static void VisitBignum(const string_t &bignum, VariantNormalizerState &state); + static void VisitGeometry(const string_t &geom, VariantNormalizerState &state); + static void VisitBitstring(const string_t &bits, VariantNormalizerState &state); + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, VariantNormalizerState &state) { + state.blob_size += VarintEncode(width, state.GetDestination()); + state.blob_size += VarintEncode(scale, state.GetDestination()); + Store(val, state.GetDestination()); + state.blob_size += sizeof(T); + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state); + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state); + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantNormalizerState &state); + +public: + static void Normalize(Vector &input, Vector &output, idx_t count); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp new file mode 100644 index 000000000..b30438b84 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp @@ -0,0 +1,117 @@ +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/uuid.hpp" +#include "duckdb/common/string_map_set.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +namespace duckdb { + +struct VariantColumnStatsData { +public: + explicit VariantColumnStatsData(idx_t index) : index(index) { + } + +public: + void SetType(VariantLogicalType type); + +public: + //! The index in the 'columns' of the VariantShreddingStats + idx_t index; + //! Count of each variant type encountered + idx_t type_counts[static_cast(VariantLogicalType::ENUM_SIZE)] = {0}; + uint32_t decimal_width; + uint32_t decimal_scale; + bool decimal_consistent = false; + + idx_t total_count = 0; + //! indices into the top-level 'columns' vector where the stats for the field/element live + case_insensitive_map_t field_stats; + idx_t element_stats = DConstants::INVALID_INDEX; +}; + +struct VariantShreddingStats { +public: + VariantShreddingStats() { + columns.emplace_back(0); + } + +public: + VariantColumnStatsData &GetOrCreateElement(idx_t parent_index); + VariantColumnStatsData &GetOrCreateField(idx_t parent_index, const string &name); + + VariantColumnStatsData &GetColumnStats(idx_t index); + const VariantColumnStatsData &GetColumnStats(idx_t index) const; + +public: + void Update(Vector &input, idx_t count); + LogicalType GetShreddedType() const; + +private: + bool GetShreddedTypeInternal(const VariantColumnStatsData &column, LogicalType &out_type) const; + +private: + //! 
Nested type analysis + vector columns; +}; + +struct VariantShredding { +public: + VariantShredding() { + } + virtual ~VariantShredding() = default; + +public: + static LogicalType GetUnshreddedType() { + return LogicalType::STRUCT(StructType::GetChildTypes(LogicalType::VARIANT())); + } + +public: + virtual void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) = 0; + +protected: + void WriteTypedValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, idx_t count); + +private: + void WriteTypedObjectValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, idx_t count); + void WriteTypedArrayValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, idx_t count); + void WriteTypedPrimitiveValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count); +}; + +struct VariantShreddingState { +public: + explicit VariantShreddingState(const LogicalType &type, idx_t total_count); + virtual ~VariantShreddingState() { + } + +public: + bool ValueIsShredded(UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index); + void SetShredded(uint32_t row, uint32_t values_index, uint32_t result_idx); + case_insensitive_string_set_t ObjectFields(); + virtual const unordered_set &GetVariantTypes() = 0; + +public: + //! The type the field is shredded on + const LogicalType &type; + //! row that is shredded + SelectionVector shredded_sel; + //! 'values_index' of the shredded value + SelectionVector values_index_sel; + //! result row of the shredded value + SelectionVector result_sel; + //! 
The amount of rows that are shredded on + idx_t count = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_value_convert.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_value_convert.hpp new file mode 100644 index 000000000..3b83418fa --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/variant/variant_value_convert.hpp @@ -0,0 +1,140 @@ +#pragma once + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/variant_visitor.hpp" +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +struct ValueConverter { + using result_type = Value; + + static Value VisitNull() { + return Value(LogicalType::SQLNULL); + } + + static Value VisitBoolean(bool val) { + return Value::BOOLEAN(val); + } + + template + static Value VisitInteger(T val) { + throw InternalException("ValueConverter::VisitInteger not implemented!"); + } + + static Value VisitTime(dtime_t val) { + return Value::TIME(val); + } + + static Value VisitTimeNanos(dtime_ns_t val) { + return Value::TIME_NS(val); + } + + static Value VisitTimeTZ(dtime_tz_t val) { + return Value::TIMETZ(val); + } + + static Value VisitTimestampSec(timestamp_sec_t val) { + return Value::TIMESTAMPSEC(val); + } + + static Value VisitTimestampMs(timestamp_ms_t val) { + return Value::TIMESTAMPMS(val); + } + + static Value VisitTimestamp(timestamp_t val) { + return Value::TIMESTAMP(val); + } + + static Value VisitTimestampNanos(timestamp_ns_t val) { + return Value::TIMESTAMPNS(val); + } + + static Value VisitTimestampTZ(timestamp_tz_t val) { + return Value::TIMESTAMPTZ(val); + } + + static Value VisitFloat(float val) { + return Value::FLOAT(val); + } + static Value VisitDouble(double val) { + return Value::DOUBLE(val); + } + static Value VisitUUID(hugeint_t val) { + return Value::UUID(val); + } + static Value VisitDate(date_t val) { + return Value::DATE(val); + } + static Value VisitInterval(interval_t val) { + return Value::INTERVAL(val); + } + + static Value VisitString(const string_t &str) { + return Value(str); + } + static Value VisitBlob(const string_t &str) { + return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBignum(const string_t &str) { + return Value::BIGNUM(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitGeometry(const string_t &str) { + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + static Value VisitBitstring(const string_t &str) { + return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); + } + + template + static Value VisitDecimal(T val, uint32_t width, uint32_t scale) { + if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else if (std::is_same::value) { + return Value::DECIMAL(val, static_cast(width), static_cast(scale)); + } else { + throw InternalException("Unhandled decimal type"); + } + } + + static Value VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + auto array_items = VariantVisitor::VisitArrayItems(variant, row, nested_data); + return Value::LIST(LogicalType::VARIANT(), std::move(array_items)); + } + + static Value VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data) { + 
auto object_children = VariantVisitor::VisitObjectItems(variant, row, nested_data); + return Value::STRUCT(std::move(object_children)); + } + + static Value VisitDefault(VariantLogicalType type_id, const_data_ptr_t) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); + } +}; + +template <> +Value ValueConverter::VisitInteger(int8_t val); +template <> +Value ValueConverter::VisitInteger(int16_t val); +template <> +Value ValueConverter::VisitInteger(int32_t val); +template <> +Value ValueConverter::VisitInteger(int64_t val); +template <> +Value ValueConverter::VisitInteger(hugeint_t val); +template <> +Value ValueConverter::VisitInteger(uint8_t val); +template <> +Value ValueConverter::VisitInteger(uint16_t val); +template <> +Value ValueConverter::VisitInteger(uint32_t val); +template <> +Value ValueConverter::VisitInteger(uint64_t val); +template <> +Value ValueConverter::VisitInteger(uhugeint_t val); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_aggregate_states.hpp b/src/duckdb/src/include/duckdb/function/window/window_aggregate_states.hpp index 1382a522c..18f39cf39 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_aggregate_states.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_aggregate_states.hpp @@ -13,7 +13,7 @@ namespace duckdb { struct WindowAggregateStates { - explicit WindowAggregateStates(const AggregateObject &aggr); + WindowAggregateStates(ClientContext &client, const AggregateObject &aggr); ~WindowAggregateStates() { Destroy(); } @@ -34,13 +34,14 @@ struct WindowAggregateStates { //! Initialise all the states void Initialize(idx_t count); //! Combine the states into the target - void Combine(WindowAggregateStates &target, - AggregateCombineType combine_type = AggregateCombineType::PRESERVE_INPUT); + void Combine(WindowAggregateStates &target); //! Finalize the states into an output vector void Finalize(Vector &result); //! Destroy the states void Destroy(); + //! The context to use for memory etc. + ClientContext &client; //! A description of the aggregator const AggregateObject aggr; //! 
The size of each state diff --git a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp index 11c724d9b..4f007d83b 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp @@ -89,7 +89,6 @@ struct WindowInputExpression { }; struct WindowBoundariesState { - static bool HasPrecedingRange(const BoundWindowExpression &wexpr); static bool HasFollowingRange(const BoundWindowExpression &wexpr); static WindowBoundsSet GetWindowBounds(const BoundWindowExpression &wexpr); diff --git a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp index 95cf0534f..2dae27c6a 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp @@ -190,7 +190,6 @@ class WindowCollectionChunkScanner { template static void WindowDeltaScanner(ColumnDataCollection &collection, idx_t block_begin, idx_t block_end, const vector &scan_cols, const idx_t key_count, OP operation) { - // Stop if there is no work to do if (!collection.Count()) { return; diff --git a/src/duckdb/src/include/duckdb/logging/log_manager.hpp b/src/duckdb/src/include/duckdb/logging/log_manager.hpp index 6ee88aeda..54f623a55 100644 --- a/src/duckdb/src/include/duckdb/logging/log_manager.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_manager.hpp @@ -21,7 +21,7 @@ class LogType; // - Creates Loggers with cached configuration // - Main sink for logs (either by logging directly into this, or by syncing a pre-cached set of log entries) // - Holds the log storage -class LogManager : public enable_shared_from_this { +class LogManager { friend class ThreadSafeLogger; friend class ThreadLocalLogger; friend class MutableLogger; diff --git a/src/duckdb/src/include/duckdb/logging/log_type.hpp b/src/duckdb/src/include/duckdb/logging/log_type.hpp index 23d901c4e..a5e55bbd3 100644 --- a/src/duckdb/src/include/duckdb/logging/log_type.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_type.hpp @@ -20,6 +20,7 @@ class PhysicalOperator; class AttachedDatabase; class RowGroup; struct DataTableInfo; +enum class MetricType : uint8_t; //! Log types provide some structure to the formats that the different log messages can have //! For now, this holds a type that the VARCHAR value will be auto-cast into. @@ -32,6 +33,9 @@ class LogType { //! Construct a structured type LogType(const string &name_p, const LogLevel &level_p, LogicalType structured_type) : name(name_p), level(level_p), is_structured(true), type(std::move(structured_type)) { + if (!type.IsNested()) { + throw InternalException("LogType must be nested if the type is explicitly set"); + } } string name; @@ -106,6 +110,19 @@ class PhysicalOperatorLogType : public LogType { const vector> &info); }; +class MetricsLogType : public LogType { +public: + static constexpr const char *NAME = "Metrics"; + static constexpr LogLevel LEVEL = LogLevel::LOG_INFO; + + //! 
Construct the log type + MetricsLogType(); + + static LogicalType GetLogType(); + + static string ConstructLogMessage(const MetricType &type, const Value &value); +}; + class CheckpointLogType : public LogType { public: static constexpr const char *NAME = "Checkpoint"; @@ -121,11 +138,25 @@ class CheckpointLogType : public LogType { idx_t merge_count, idx_t target_count, idx_t merge_rows, idx_t row_start); //! Checkpoint static string ConstructLogMessage(const AttachedDatabase &db, DataTableInfo &table, idx_t segment_idx, - RowGroup &row_group); + RowGroup &row_group, idx_t row_group_start); private: static string CreateLog(const AttachedDatabase &db, DataTableInfo &table, const char *op, vector map_keys, vector map_values); }; +class TransactionLogType : public LogType { +public: + static constexpr const char *NAME = "Transaction"; + static constexpr LogLevel LEVEL = LogLevel::LOG_DEBUG; + + //! Construct the log type + TransactionLogType(); + + static LogicalType GetLogType(); + + static string ConstructLogMessage(const AttachedDatabase &db, const char *log_type, + transaction_t transaction_id = MAX_TRANSACTION_ID); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/logging/logger.hpp b/src/duckdb/src/include/duckdb/logging/logger.hpp index 048588958..0c132fe2c 100644 --- a/src/duckdb/src/include/duckdb/logging/logger.hpp +++ b/src/duckdb/src/include/duckdb/logging/logger.hpp @@ -40,7 +40,8 @@ struct FileHandle; #define DUCKDB_LOG_DEBUG(SOURCE, ...) \ DUCKDB_LOG_INTERNAL(SOURCE, DefaultLogType::NAME, LogLevel::LOG_DEBUG, __VA_ARGS__) #define DUCKDB_LOG_INFO(SOURCE, ...) DUCKDB_LOG_INTERNAL(SOURCE, DefaultLogType::NAME, LogLevel::LOG_INFO, __VA_ARGS__) -#define DUCKDB_LOG_WARN(SOURCE, ...) DUCKDB_LOG_INTERNAL(SOURCE, DefaultLogType::NAME, LogLevel::LOG_WARN, __VA_ARGS__) +#define DUCKDB_LOG_WARNING(SOURCE, ...) \ + DUCKDB_LOG_INTERNAL(SOURCE, DefaultLogType::NAME, LogLevel::LOG_WARNING, __VA_ARGS__) #define DUCKDB_LOG_ERROR(SOURCE, ...) \ DUCKDB_LOG_INTERNAL(SOURCE, DefaultLogType::NAME, LogLevel::LOG_ERROR, __VA_ARGS__) #define DUCKDB_LOG_FATAL(SOURCE, ...) \ diff --git a/src/duckdb/src/include/duckdb/logging/logging.hpp b/src/duckdb/src/include/duckdb/logging/logging.hpp index 36cc7126f..11d0760c7 100644 --- a/src/duckdb/src/include/duckdb/logging/logging.hpp +++ b/src/duckdb/src/include/duckdb/logging/logging.hpp @@ -22,7 +22,7 @@ enum class LogLevel : uint8_t { LOG_TRACE = 10, LOG_DEBUG = 20, LOG_INFO = 30, - LOG_WARN = 40, + LOG_WARNING = 40, LOG_ERROR = 50, LOG_FATAL = 60 }; diff --git a/src/duckdb/src/include/duckdb/main/appender.hpp b/src/duckdb/src/include/duckdb/main/appender.hpp index b32025cb0..b637717f8 100644 --- a/src/duckdb/src/include/duckdb/main/appender.hpp +++ b/src/duckdb/src/include/duckdb/main/appender.hpp @@ -48,6 +48,8 @@ class BaseAppender { AppenderType appender_type; //! The amount of rows after which the appender flushes automatically. idx_t flush_count = DEFAULT_FLUSH_COUNT; + //! Peak allocation threshold at which to flush the allocator when the appender flushes a chunk. + optional_idx flush_memory_threshold; protected: DUCKDB_API BaseAppender(Allocator &allocator, const AppenderType type); @@ -82,6 +84,8 @@ class BaseAppender { DUCKDB_API void Flush(); //! Flush the changes made by the appender and close it. The appender cannot be used after this point DUCKDB_API void Close(); + //! Clears any appended data (without flushing). + DUCKDB_API void Clear(); //! Returns the active types of the appender.
const vector &GetActiveTypes() const; @@ -105,6 +109,9 @@ class BaseAppender { void InitializeChunk(); void FlushChunk(); + bool ShouldFlushChunk() const; + bool ShouldFlush() const; + template void AppendValueInternal(T value); template @@ -129,9 +136,11 @@ class BaseAppender { class Appender : public BaseAppender { public: DUCKDB_API Appender(Connection &con, const string &database_name, const string &schema_name, - const string &table_name); - DUCKDB_API Appender(Connection &con, const string &schema_name, const string &table_name); - DUCKDB_API Appender(Connection &con, const string &table_name); + const string &table_name, const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); + DUCKDB_API Appender(Connection &con, const string &schema_name, const string &table_name, + const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); + DUCKDB_API Appender(Connection &con, const string &table_name, + const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); DUCKDB_API ~Appender() override; public: @@ -160,7 +169,8 @@ class Appender : public BaseAppender { class QueryAppender : public BaseAppender { public: DUCKDB_API QueryAppender(Connection &con, string query, vector types, - vector names = vector(), string table_name = string()); + vector names = vector(), string table_name = string(), + const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); DUCKDB_API ~QueryAppender() override; private: @@ -185,7 +195,8 @@ class InternalAppender : public BaseAppender { public: DUCKDB_API InternalAppender(ClientContext &context, TableCatalogEntry &table, - const idx_t flush_count = DEFAULT_FLUSH_COUNT); + const idx_t flush_count = DEFAULT_FLUSH_COUNT, + const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); DUCKDB_API ~InternalAppender() override; protected: diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp index 7333d9adb..d75a7b922 100644 --- a/src/duckdb/src/include/duckdb/main/attached_database.hpp +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -34,14 +34,23 @@ enum class AttachedDatabaseType { enum class AttachVisibility { SHOWN, HIDDEN }; +//! DEFAULT is the standard ACID crash recovery mode. +//! NO_WAL_WRITES disables the WAL for the attached database, i.e., disabling the D in ACID. +//! Use this mode with caution, as it disables recovery from crashes for the file. +enum class RecoveryMode : uint8_t { DEFAULT = 0, NO_WAL_WRITES = 1 }; + class DatabaseFilePathManager; struct StoredDatabasePath { - StoredDatabasePath(DatabaseFilePathManager &manager, string path, const string &name); + StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path, const string &name); ~StoredDatabasePath(); + DatabaseManager &db_manager; DatabaseFilePathManager &manager; string path; + +public: + void OnDetach(); }; //! AttachOptions holds information about a database we plan to attach. These options are generalized, i.e., @@ -54,6 +63,8 @@ struct AttachOptions { //! Defaults to the access mode configured in the DBConfig, unless specified otherwise. AccessMode access_mode; + //! The recovery type of the database. + RecoveryMode recovery_mode = RecoveryMode::DEFAULT; //! The file format type. The default type is a duckdb database file, but other file formats are possible. string db_type; //! 
Set of remaining (key, value) options @@ -112,9 +123,13 @@ class AttachedDatabase : public CatalogEntry, public enable_shared_from_this parent_catalog; optional_ptr storage_extension; + RecoveryMode recovery_mode = RecoveryMode::DEFAULT; AttachVisibility visibility = AttachVisibility::SHOWN; bool is_initial_database = false; bool is_closed = false; diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp index c3ffadbe2..57ea17d6b 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/batched_buffered_data.hpp @@ -32,7 +32,7 @@ class BatchedBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::BATCHED; public: - explicit BatchedBufferedData(weak_ptr context); + explicit BatchedBufferedData(ClientContext &context); public: void Append(const DataChunk &chunk, idx_t batch); diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp index 0f32675ce..06a72b0f6 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/buffered_data.hpp @@ -28,7 +28,7 @@ class BufferedData { enum class Type { SIMPLE, BATCHED }; public: - BufferedData(Type type, weak_ptr context_p); + BufferedData(Type type, ClientContext &context); virtual ~BufferedData(); public: diff --git a/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp b/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp index 967cc1ab7..40a5a6ede 100644 --- a/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp +++ b/src/duckdb/src/include/duckdb/main/buffered_data/simple_buffered_data.hpp @@ -24,7 +24,7 @@ class SimpleBufferedData : public BufferedData { static constexpr const BufferedData::Type TYPE = BufferedData::Type::SIMPLE; public: - explicit SimpleBufferedData(weak_ptr context); + explicit SimpleBufferedData(ClientContext &context); ~SimpleBufferedData() override; public: diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp index 8307b70a3..3b736d88a 100644 --- a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp @@ -51,6 +51,8 @@ struct PreparedStatementWrapper { //! 
Map of name -> values case_insensitive_map_t values; unique_ptr statement; + bool success = true; + ErrorData error_data; }; struct ExtractStatementsWrapper { diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp new file mode 100644 index 000000000..f51947cf5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal_table.hpp @@ -0,0 +1,74 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/capi/capi_internal_table.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb.h" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/function/table_function.hpp" + +namespace duckdb { + +// These need to be shared by both the table function API and the copy function API + +struct CTableFunctionInfo : public TableFunctionInfo { + ~CTableFunctionInfo() override { + if (extra_info && delete_callback) { + delete_callback(extra_info); + } + extra_info = nullptr; + delete_callback = nullptr; + } + + duckdb_table_function_bind_t bind = nullptr; + duckdb_table_function_init_t init = nullptr; + duckdb_table_function_init_t local_init = nullptr; + duckdb_table_function_t function = nullptr; + void *extra_info = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +struct CTableBindData : public TableFunctionData { + explicit CTableBindData(CTableFunctionInfo &info) : info(info) { + } + ~CTableBindData() override { + if (bind_data && delete_callback) { + delete_callback(bind_data); + } + bind_data = nullptr; + delete_callback = nullptr; + } + + CTableFunctionInfo &info; + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + unique_ptr stats; +}; + +struct CTableInternalBindInfo { + CTableInternalBindInfo(ClientContext &context, const vector ¶meters, + const named_parameter_map_t &named_parameters, vector &return_types, + vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) + : context(context), parameters(parameters), named_parameters(named_parameters), return_types(return_types), + names(names), bind_data(bind_data), function_info(function_info), success(true) { + } + + ClientContext &context; + + vector parameters; + named_parameter_map_t named_parameters; + + vector &return_types; + vector &names; + CTableBindData &bind_data; + CTableFunctionInfo &function_info; + bool success; + string error; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp b/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp index 2a9bf7644..34fdcf367 100644 --- a/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/cast/from_decimal.hpp @@ -20,22 +20,21 @@ bool CastDecimalCInternal(duckdb_result *source, RESULT_TYPE &result, idx_t col, auto &source_type = query_result->types[col]; auto width = duckdb::DecimalType::GetWidth(source_type); auto scale = duckdb::DecimalType::GetScale(source_type); - void *source_address = UnsafeFetchPtr(source, col, row); - + auto source_value = UnsafeFetch(source, col, row); CastParameters parameters; switch (source_type.InternalType()) { case duckdb::PhysicalType::INT16: - return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - result, parameters, width, scale); + return duckdb::TryCastFromDecimal::Operation(static_cast(source_value), result, 
+ parameters, width, scale); case duckdb::PhysicalType::INT32: - return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - result, parameters, width, scale); + return duckdb::TryCastFromDecimal::Operation(static_cast(source_value), result, + parameters, width, scale); case duckdb::PhysicalType::INT64: - return duckdb::TryCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - result, parameters, width, scale); + return duckdb::TryCastFromDecimal::Operation(static_cast(source_value), result, + parameters, width, scale); case duckdb::PhysicalType::INT128: - return duckdb::TryCastFromDecimal::Operation( - UnsafeFetchFromPtr(source_address), result, parameters, width, scale); + return duckdb::TryCastFromDecimal::Operation(source_value, result, parameters, width, + scale); default: throw duckdb::InternalException("Unimplemented internal type for decimal"); } diff --git a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp index 2ce10061a..f3fee0d91 100644 --- a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp @@ -475,6 +475,7 @@ typedef struct { duckdb_state (*duckdb_appender_create_query)(duckdb_connection connection, const char *query, idx_t column_count, duckdb_logical_type *types, const char *table_name, const char **column_names, duckdb_appender *out_appender); + duckdb_state (*duckdb_appender_clear)(duckdb_appender appender); // New arrow interface functions duckdb_error_data (*duckdb_to_arrow_schema)(duckdb_arrow_options arrow_options, duckdb_logical_type *types, @@ -487,6 +488,76 @@ typedef struct { duckdb_arrow_converted_schema converted_schema, duckdb_data_chunk *out_chunk); void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); + // New functions for interacting with catalog entries + + duckdb_catalog (*duckdb_client_context_get_catalog)(duckdb_client_context context, const char *catalog_name); + const char *(*duckdb_catalog_get_type_name)(duckdb_catalog catalog); + duckdb_catalog_entry (*duckdb_catalog_get_entry)(duckdb_catalog catalog, duckdb_client_context context, + duckdb_catalog_entry_type entry_type, const char *schema_name, + const char *entry_name); + void (*duckdb_destroy_catalog)(duckdb_catalog *catalog); + duckdb_catalog_entry_type (*duckdb_catalog_entry_get_type)(duckdb_catalog_entry entry); + const char *(*duckdb_catalog_entry_get_name)(duckdb_catalog_entry entry); + void (*duckdb_destroy_catalog_entry)(duckdb_catalog_entry *entry); + // New configuration options functions + + duckdb_config_option (*duckdb_create_config_option)(); + void (*duckdb_destroy_config_option)(duckdb_config_option *option); + void (*duckdb_config_option_set_name)(duckdb_config_option option, const char *name); + void (*duckdb_config_option_set_type)(duckdb_config_option option, duckdb_logical_type type); + void (*duckdb_config_option_set_default_value)(duckdb_config_option option, duckdb_value default_value); + void (*duckdb_config_option_set_default_scope)(duckdb_config_option option, + duckdb_config_option_scope default_scope); + void (*duckdb_config_option_set_description)(duckdb_config_option option, const char *description); + duckdb_state (*duckdb_register_config_option)(duckdb_connection connection, duckdb_config_option option); + duckdb_value (*duckdb_client_context_get_config_option)(duckdb_client_context context, const char *name, + duckdb_config_option_scope 
*out_scope); + // API to define custom copy functions + + duckdb_copy_function (*duckdb_create_copy_function)(); + void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); + void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + duckdb_state (*duckdb_register_copy_function)(duckdb_connection connection, duckdb_copy_function copy_function); + void (*duckdb_destroy_copy_function)(duckdb_copy_function *copy_function); + void (*duckdb_copy_function_set_bind)(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + void (*duckdb_copy_function_bind_set_error)(duckdb_copy_function_bind_info info, const char *error); + void *(*duckdb_copy_function_bind_get_extra_info)(duckdb_copy_function_bind_info info); + duckdb_client_context (*duckdb_copy_function_bind_get_client_context)(duckdb_copy_function_bind_info info); + idx_t (*duckdb_copy_function_bind_get_column_count)(duckdb_copy_function_bind_info info); + duckdb_logical_type (*duckdb_copy_function_bind_get_column_type)(duckdb_copy_function_bind_info info, + idx_t col_idx); + duckdb_value (*duckdb_copy_function_bind_get_options)(duckdb_copy_function_bind_info info); + void (*duckdb_copy_function_bind_set_bind_data)(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + void (*duckdb_copy_function_set_global_init)(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + void (*duckdb_copy_function_global_init_set_error)(duckdb_copy_function_global_init_info info, const char *error); + void *(*duckdb_copy_function_global_init_get_extra_info)(duckdb_copy_function_global_init_info info); + duckdb_client_context (*duckdb_copy_function_global_init_get_client_context)( + duckdb_copy_function_global_init_info info); + void *(*duckdb_copy_function_global_init_get_bind_data)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_global_init_set_global_state)(duckdb_copy_function_global_init_info info, + void *global_state, duckdb_delete_callback_t destructor); + const char *(*duckdb_copy_function_global_init_get_file_path)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_set_sink)(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function); + void (*duckdb_copy_function_sink_set_error)(duckdb_copy_function_sink_info info, const char *error); + void *(*duckdb_copy_function_sink_get_extra_info)(duckdb_copy_function_sink_info info); + duckdb_client_context (*duckdb_copy_function_sink_get_client_context)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_bind_data)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_global_state)(duckdb_copy_function_sink_info info); + void (*duckdb_copy_function_set_finalize)(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + void (*duckdb_copy_function_finalize_set_error)(duckdb_copy_function_finalize_info info, const char *error); + void *(*duckdb_copy_function_finalize_get_extra_info)(duckdb_copy_function_finalize_info info); + duckdb_client_context (*duckdb_copy_function_finalize_get_client_context)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_bind_data)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_global_state)(duckdb_copy_function_finalize_info info); + void 
(*duckdb_copy_function_set_copy_from_function)(duckdb_copy_function copy_function, + duckdb_table_function table_function); + idx_t (*duckdb_table_function_bind_get_result_column_count)(duckdb_bind_info info); + const char *(*duckdb_table_function_bind_get_result_column_name)(duckdb_bind_info info, idx_t col_idx); + duckdb_logical_type (*duckdb_table_function_bind_get_result_column_type)(duckdb_bind_info info, idx_t col_idx); // New functions for duckdb error data duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -521,6 +592,16 @@ typedef struct { int64_t (*duckdb_file_handle_tell)(duckdb_file_handle file_handle); duckdb_state (*duckdb_file_handle_sync)(duckdb_file_handle file_handle); int64_t (*duckdb_file_handle_size)(duckdb_file_handle file_handle); + // API to register a custom log storage. + + duckdb_log_storage (*duckdb_create_log_storage)(); + void (*duckdb_destroy_log_storage)(duckdb_log_storage *log_storage); + void (*duckdb_log_storage_set_write_log_entry)(duckdb_log_storage log_storage, + duckdb_logger_write_log_entry_t function); + void (*duckdb_log_storage_set_extra_data)(duckdb_log_storage log_storage, void *extra_data, + duckdb_delete_callback_t delete_callback); + void (*duckdb_log_storage_set_name)(duckdb_log_storage log_storage, const char *name); + duckdb_state (*duckdb_register_log_storage)(duckdb_database database, duckdb_log_storage log_storage); // New functions around the client context idx_t (*duckdb_client_context_get_connection_id)(duckdb_client_context context); @@ -554,6 +635,11 @@ typedef struct { // New string functions that are added char *(*duckdb_value_to_string)(duckdb_value value); + // New functions around the table description + + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); // New functions around table function binding void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -993,11 +1079,64 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_append_default_to_chunk = duckdb_append_default_to_chunk; result.duckdb_appender_error_data = duckdb_appender_error_data; result.duckdb_appender_create_query = duckdb_appender_create_query; + result.duckdb_appender_clear = duckdb_appender_clear; result.duckdb_to_arrow_schema = duckdb_to_arrow_schema; result.duckdb_data_chunk_to_arrow = duckdb_data_chunk_to_arrow; result.duckdb_schema_from_arrow = duckdb_schema_from_arrow; result.duckdb_data_chunk_from_arrow = duckdb_data_chunk_from_arrow; result.duckdb_destroy_arrow_converted_schema = duckdb_destroy_arrow_converted_schema; + result.duckdb_client_context_get_catalog = duckdb_client_context_get_catalog; + result.duckdb_catalog_get_type_name = duckdb_catalog_get_type_name; + result.duckdb_catalog_get_entry = duckdb_catalog_get_entry; + result.duckdb_destroy_catalog = duckdb_destroy_catalog; + result.duckdb_catalog_entry_get_type = duckdb_catalog_entry_get_type; + result.duckdb_catalog_entry_get_name = duckdb_catalog_entry_get_name; + result.duckdb_destroy_catalog_entry = duckdb_destroy_catalog_entry; + result.duckdb_create_config_option = duckdb_create_config_option; + result.duckdb_destroy_config_option = duckdb_destroy_config_option; + result.duckdb_config_option_set_name = duckdb_config_option_set_name; + result.duckdb_config_option_set_type = duckdb_config_option_set_type; + 
result.duckdb_config_option_set_default_value = duckdb_config_option_set_default_value; + result.duckdb_config_option_set_default_scope = duckdb_config_option_set_default_scope; + result.duckdb_config_option_set_description = duckdb_config_option_set_description; + result.duckdb_register_config_option = duckdb_register_config_option; + result.duckdb_client_context_get_config_option = duckdb_client_context_get_config_option; + result.duckdb_create_copy_function = duckdb_create_copy_function; + result.duckdb_copy_function_set_name = duckdb_copy_function_set_name; + result.duckdb_copy_function_set_extra_info = duckdb_copy_function_set_extra_info; + result.duckdb_register_copy_function = duckdb_register_copy_function; + result.duckdb_destroy_copy_function = duckdb_destroy_copy_function; + result.duckdb_copy_function_set_bind = duckdb_copy_function_set_bind; + result.duckdb_copy_function_bind_set_error = duckdb_copy_function_bind_set_error; + result.duckdb_copy_function_bind_get_extra_info = duckdb_copy_function_bind_get_extra_info; + result.duckdb_copy_function_bind_get_client_context = duckdb_copy_function_bind_get_client_context; + result.duckdb_copy_function_bind_get_column_count = duckdb_copy_function_bind_get_column_count; + result.duckdb_copy_function_bind_get_column_type = duckdb_copy_function_bind_get_column_type; + result.duckdb_copy_function_bind_get_options = duckdb_copy_function_bind_get_options; + result.duckdb_copy_function_bind_set_bind_data = duckdb_copy_function_bind_set_bind_data; + result.duckdb_copy_function_set_global_init = duckdb_copy_function_set_global_init; + result.duckdb_copy_function_global_init_set_error = duckdb_copy_function_global_init_set_error; + result.duckdb_copy_function_global_init_get_extra_info = duckdb_copy_function_global_init_get_extra_info; + result.duckdb_copy_function_global_init_get_client_context = duckdb_copy_function_global_init_get_client_context; + result.duckdb_copy_function_global_init_get_bind_data = duckdb_copy_function_global_init_get_bind_data; + result.duckdb_copy_function_global_init_set_global_state = duckdb_copy_function_global_init_set_global_state; + result.duckdb_copy_function_global_init_get_file_path = duckdb_copy_function_global_init_get_file_path; + result.duckdb_copy_function_set_sink = duckdb_copy_function_set_sink; + result.duckdb_copy_function_sink_set_error = duckdb_copy_function_sink_set_error; + result.duckdb_copy_function_sink_get_extra_info = duckdb_copy_function_sink_get_extra_info; + result.duckdb_copy_function_sink_get_client_context = duckdb_copy_function_sink_get_client_context; + result.duckdb_copy_function_sink_get_bind_data = duckdb_copy_function_sink_get_bind_data; + result.duckdb_copy_function_sink_get_global_state = duckdb_copy_function_sink_get_global_state; + result.duckdb_copy_function_set_finalize = duckdb_copy_function_set_finalize; + result.duckdb_copy_function_finalize_set_error = duckdb_copy_function_finalize_set_error; + result.duckdb_copy_function_finalize_get_extra_info = duckdb_copy_function_finalize_get_extra_info; + result.duckdb_copy_function_finalize_get_client_context = duckdb_copy_function_finalize_get_client_context; + result.duckdb_copy_function_finalize_get_bind_data = duckdb_copy_function_finalize_get_bind_data; + result.duckdb_copy_function_finalize_get_global_state = duckdb_copy_function_finalize_get_global_state; + result.duckdb_copy_function_set_copy_from_function = duckdb_copy_function_set_copy_from_function; + result.duckdb_table_function_bind_get_result_column_count = 
duckdb_table_function_bind_get_result_column_count; + result.duckdb_table_function_bind_get_result_column_name = duckdb_table_function_bind_get_result_column_name; + result.duckdb_table_function_bind_get_result_column_type = duckdb_table_function_bind_get_result_column_type; result.duckdb_create_error_data = duckdb_create_error_data; result.duckdb_destroy_error_data = duckdb_destroy_error_data; result.duckdb_error_data_error_type = duckdb_error_data_error_type; @@ -1023,6 +1162,12 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_file_handle_tell = duckdb_file_handle_tell; result.duckdb_file_handle_sync = duckdb_file_handle_sync; result.duckdb_file_handle_size = duckdb_file_handle_size; + result.duckdb_create_log_storage = duckdb_create_log_storage; + result.duckdb_destroy_log_storage = duckdb_destroy_log_storage; + result.duckdb_log_storage_set_write_log_entry = duckdb_log_storage_set_write_log_entry; + result.duckdb_log_storage_set_extra_data = duckdb_log_storage_set_extra_data; + result.duckdb_log_storage_set_name = duckdb_log_storage_set_name; + result.duckdb_register_log_storage = duckdb_register_log_storage; result.duckdb_client_context_get_connection_id = duckdb_client_context_get_connection_id; result.duckdb_destroy_client_context = duckdb_destroy_client_context; result.duckdb_connection_get_client_context = duckdb_connection_get_client_context; @@ -1044,6 +1189,8 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_scalar_function_bind_get_argument = duckdb_scalar_function_bind_get_argument; result.duckdb_scalar_function_set_bind_data_copy = duckdb_scalar_function_set_bind_data_copy; result.duckdb_value_to_string = duckdb_value_to_string; + result.duckdb_table_description_get_column_count = duckdb_table_description_get_column_count; + result.duckdb_table_description_get_column_type = duckdb_table_description_get_column_type; result.duckdb_table_function_get_client_context = duckdb_table_function_get_client_context; result.duckdb_create_map_value = duckdb_create_map_value; result.duckdb_create_union_value = duckdb_create_union_value; diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp index f9e673b19..2485ffdf4 100644 --- a/src/duckdb/src/include/duckdb/main/client_config.hpp +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -40,7 +40,7 @@ struct ClientConfig { string profiler_save_location; //! The custom settings for the profiler //! (empty = use the default settings) - profiler_settings_t profiler_settings = ProfilingInfo::DefaultSettings(); + profiler_settings_t profiler_settings = MetricsUtils::GetDefaultMetrics(); //! Allows suppressing profiler output, even if enabled. We turn on the profiler on all test runs but don't want //! 
to output anything @@ -121,7 +121,7 @@ struct ClientConfig { bool AnyVerification() const; - void SetUserVariable(const string &name, Value value); + void SetUserVariable(const String &name, Value value); bool GetUserVariable(const string &name, Value &result); void ResetUserVariable(const String &name); diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp index ddb14518c..00ce7e2d2 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -28,6 +28,7 @@ #include "duckdb/main/table_description.hpp" #include "duckdb/planner/expression/bound_parameter_data.hpp" #include "duckdb/transaction/transaction_context.hpp" +#include "duckdb/main/query_parameters.hpp" namespace duckdb { @@ -56,8 +57,8 @@ class RegisteredStateManager; struct PendingQueryParameters { //! Prepared statement parameters (if any) optional_ptr> parameters; - //! Whether a stream result should be allowed - bool allow_stream_result = false; + //! Whether a stream/buffer-managed result should be allowed + QueryParameters query_parameters; }; //! The ClientContext holds information relevant to the current client session @@ -96,6 +97,7 @@ class ClientContext : public enable_shared_from_this { //! Interrupt execution of a query DUCKDB_API void Interrupt(); + DUCKDB_API bool IsInterrupted() const; DUCKDB_API void CancelTransaction(); //! Enable query profiling @@ -106,22 +108,24 @@ class ClientContext : public enable_shared_from_this { //! Issue a query, returning a QueryResult. The QueryResult can be either a StreamQueryResult or a //! MaterializedQueryResult. The StreamQueryResult will only be returned in the case of a successful SELECT //! statement. - DUCKDB_API unique_ptr Query(const string &query, bool allow_stream_result); - DUCKDB_API unique_ptr Query(unique_ptr statement, bool allow_stream_result); + DUCKDB_API unique_ptr Query(const string &query, QueryParameters query_parameters); + DUCKDB_API unique_ptr Query(unique_ptr statement, QueryParameters query_parameters); //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain //! a single statement. - DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result); + DUCKDB_API unique_ptr PendingQuery(const string &query, QueryParameters query_parameters); //! Issues a query to the database and returns a Pending Query Result DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - bool allow_stream_result); + QueryParameters query_parameters); //! Create a pending query with a list of parameters DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, case_insensitive_map_t &values, - bool allow_stream_result); - DUCKDB_API unique_ptr - PendingQuery(const string &query, case_insensitive_map_t &values, bool allow_stream_result); + QueryParameters query_parameters); + DUCKDB_API unique_ptr PendingQuery(const string &query, + case_insensitive_map_t &values, + QueryParameters query_parameters); + DUCKDB_API unique_ptr PendingQuery(const string &query, PendingQueryParameters parameters); //! Destroy the client context DUCKDB_API void Destroy(); @@ -147,7 +151,7 @@ class ClientContext : public enable_shared_from_this { //! Execute a relation DUCKDB_API unique_ptr PendingQuery(const shared_ptr &relation, - bool allow_stream_result); + QueryParameters query_parameters); DUCKDB_API unique_ptr Execute(const shared_ptr &relation); //! 
Prepare a query @@ -165,9 +169,10 @@ class ClientContext : public enable_shared_from_this { //! Execute a prepared statement with the given name and set of parameters //! It is possible that the prepared statement will be re-bound. This will generally happen if the catalog is //! modified in between the prepared statement being bound and the prepared statement being run. - DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, - case_insensitive_map_t &values, - bool allow_stream_result = true); + DUCKDB_API unique_ptr + Execute(const string &query, shared_ptr &prepared, + case_insensitive_map_t &values, + QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, const PendingQueryParameters ¶meters); @@ -238,7 +243,7 @@ class ClientContext : public enable_shared_from_this { //! Perform aggressive query verification of a SELECT statement. Only called when query_verification_enabled is //! true. ErrorData VerifyQuery(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values = nullptr); + PendingQueryParameters parameters); void InitialCleanup(ClientContextLock &lock); //! Internal clean up, does not lock. Caller must hold the context_lock. @@ -259,15 +264,14 @@ class ClientContext : public enable_shared_from_this { //! Internally prepare a SQL statement. Caller must hold the context_lock. shared_ptr CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values = nullptr, + PendingQueryParameters parameters, PreparedStatementMode mode = PreparedStatementMode::PREPARE_ONLY); unique_ptr PendingStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, const PendingQueryParameters ¶meters); unique_ptr RunStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, bool allow_stream_result, - optional_ptr> params, - bool verify = true); + unique_ptr statement, + const PendingQueryParameters ¶meters, bool verify = true); unique_ptr PrepareInternal(ClientContextLock &lock, unique_ptr statement); void LogQueryInternal(ClientContextLock &lock, const string &query); @@ -292,7 +296,7 @@ class ClientContext : public enable_shared_from_this { const PendingQueryParameters ¶meters); unique_ptr PendingQueryInternal(ClientContextLock &, const shared_ptr &relation, - bool allow_stream_result); + QueryParameters query_parameters); void RebindPreparedStatement(ClientContextLock &lock, const string &query, shared_ptr &prepared, const PendingQueryParameters ¶meters); @@ -300,9 +304,9 @@ class ClientContext : public enable_shared_from_this { template unique_ptr ErrorResult(ErrorData error, const string &query = string()); - shared_ptr - CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values); + shared_ptr CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + PendingQueryParameters parameters); SettingLookupResult TryGetCurrentSettingInternal(const string &key, Value &result) const; @@ -337,6 +341,8 @@ class QueryContext { } QueryContext(optional_ptr context) : context(context) { // NOLINT: allow implicit construction } + QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction + } public: bool Valid() const { diff --git a/src/duckdb/src/include/duckdb/main/config.hpp 
b/src/duckdb/src/include/duckdb/main/config.hpp index 9a685f560..d2be0272b 100644 --- a/src/duckdb/src/include/duckdb/main/config.hpp +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -40,6 +40,7 @@ namespace duckdb { +class BlockAllocator; class BufferManager; class BufferPool; class CastFunctionSet; @@ -110,6 +111,8 @@ struct DBConfigOptions { #else bool autoinstall_known_extensions = false; #endif + //! Setting for the parser override registered by extensions. Allowed options: "default", "fallback", "strict" + string allow_parser_override_extension = "default"; //! Override for the default extension repository string custom_extension_repo = ""; //! Override for the default autoload extension repository @@ -163,12 +166,18 @@ struct DBConfigOptions { CompressionType force_compression = CompressionType::COMPRESSION_AUTO; //! Force a specific bitpacking mode to be used when using the bitpacking compression method BitpackingMode force_bitpacking_mode = BitpackingMode::AUTO; + //! Force a specific schema for VARIANT shredding + LogicalType force_variant_shredding = LogicalType::INVALID; + //! Minimum size of a rowgroup to enable VARIANT shredding, -1 to disable + int64_t variant_minimum_shredding_size = 30000; //! Database configuration variables as controlled by SET case_insensitive_map_t set_variables; //! Database configuration variable default values; case_insensitive_map_t set_variable_defaults; //! Directory to store extension binaries in string extension_directory; + //! Additional directories to store extension binaries in + vector extension_directories; //! Whether unsigned extensions should be loaded bool allow_unsigned_extensions = false; //! Whether community extensions should be loaded @@ -217,6 +226,8 @@ struct DBConfigOptions { #endif //! Whether to pin threads to cores (linux only, default AUTOMATIC: on when there are more than 64 cores) ThreadPinMode pin_threads = ThreadPinMode::AUTO; + //! Physical memory that the block allocator is allowed to use (this memory is never freed and cannot be reduced) + idx_t block_allocator_size = 0; bool operator==(const DBConfigOptions &other) const; }; @@ -244,6 +255,8 @@ struct DBConfig { unique_ptr secret_manager; //! The allocator used by the system unique_ptr allocator; + //! The block allocator used by the system + unique_ptr block_allocator; //! Database configuration options DBConfigOptions options; //! Extensions made to the parser @@ -289,6 +302,7 @@ struct DBConfig { DUCKDB_API void AddExtensionOption(const string &name, string description, LogicalType parameter, const Value &default_value = Value(), set_option_callback_t function = nullptr, SetScope default_scope = SetScope::SESSION); + DUCKDB_API bool HasExtensionOption(const string &name); //! Fetch an option by index. Returns a pointer to the option, or nullptr if out of range DUCKDB_API static optional_ptr GetOptionByIndex(idx_t index); //!
Fetch an alias by index, or nullptr if out of range @@ -300,7 +314,7 @@ struct DBConfig { DUCKDB_API void SetOptionByName(const string &name, const Value &value); DUCKDB_API void SetOptionsByName(const case_insensitive_map_t &values); DUCKDB_API void ResetOption(optional_ptr db, const ConfigurationOption &option); - DUCKDB_API void SetOption(const string &name, Value value); + DUCKDB_API void SetOption(const String &name, Value value); DUCKDB_API void ResetOption(const String &name); DUCKDB_API void ResetGenericOption(const String &name); static LogicalType ParseLogicalType(const string &type); diff --git a/src/duckdb/src/include/duckdb/main/connection.hpp b/src/duckdb/src/include/duckdb/main/connection.hpp index c27d84d21..1c88757fc 100644 --- a/src/duckdb/src/include/duckdb/main/connection.hpp +++ b/src/duckdb/src/include/duckdb/main/connection.hpp @@ -50,7 +50,6 @@ class Connection { DUCKDB_API ~Connection(); shared_ptr context; - warning_callback_t warning_cb; public: //! Returns query profiling information for the current query @@ -80,13 +79,18 @@ class Connection { //! MaterializedQueryResult. The result can be stepped through with calls to Fetch(). Note that there can only be //! one active StreamQueryResult per Connection object. Calling SendQuery() will invalidate any previously existing //! StreamQueryResult. - DUCKDB_API unique_ptr SendQuery(const string &query); + DUCKDB_API unique_ptr + SendQuery(const string &query, QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); + DUCKDB_API unique_ptr + SendQuery(unique_ptr statement, + QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); //! Issues a query to the database and materializes the result (if necessary). Always returns a //! MaterializedQueryResult. DUCKDB_API unique_ptr Query(const string &query); //! Issues a query to the database and materializes the result (if necessary). Always returns a //! MaterializedQueryResult. - DUCKDB_API unique_ptr Query(unique_ptr statement); + DUCKDB_API unique_ptr + Query(unique_ptr statement, QueryResultMemoryType memory_type = QueryResultMemoryType::IN_MEMORY); // prepared statements template unique_ptr Query(const string &query, ARGS... args) { @@ -96,20 +100,25 @@ class Connection { //! Issues a query to the database and returns a Pending Query Result. Note that "query" may only contain //! a single statement. - DUCKDB_API unique_ptr PendingQuery(const string &query, bool allow_stream_result = false); + DUCKDB_API unique_ptr + PendingQuery(const string &query, QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); //!
Issues a query to the database and returns a Pending Query Result - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - case_insensitive_map_t &named_values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(const string &query, - case_insensitive_map_t &named_values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(const string &query, vector &values, - bool allow_stream_result = false); - DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, vector &values, - bool allow_stream_result = false); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, case_insensitive_map_t &named_values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(const string &query, case_insensitive_map_t &named_values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr + PendingQuery(const string &query, vector &values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); + DUCKDB_API unique_ptr PendingQuery(const string &query, PendingQueryParameters parameters); + DUCKDB_API unique_ptr + PendingQuery(unique_ptr statement, vector &values, + QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); //! Prepare the specified query, returning a prepared statement object DUCKDB_API unique_ptr Prepare(const string &query); diff --git a/src/duckdb/src/include/duckdb/main/connection_manager.hpp b/src/duckdb/src/include/duckdb/main/connection_manager.hpp index 7fa5c66b5..1c647ce02 100644 --- a/src/duckdb/src/include/duckdb/main/connection_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/connection_manager.hpp @@ -40,7 +40,6 @@ class ConnectionManager { mutex connections_lock; reference_map_t> connections; atomic connection_count; - atomic current_connection_id; }; diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp index 2486d1e0e..7ecb0bc49 100644 --- a/src/duckdb/src/include/duckdb/main/database.hpp +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -17,6 +17,7 @@ #include "duckdb/main/extension_manager.hpp" namespace duckdb { + class BufferManager; class DatabaseManager; class StorageManager; @@ -33,6 +34,7 @@ class DatabaseFileSystem; struct DatabaseCacheEntry; class LogManager; class ExternalFileCache; +class ResultSetManager; class DatabaseInstance : public enable_shared_from_this { friend class DuckDB; @@ -51,6 +53,7 @@ class DatabaseInstance : public enable_shared_from_this { DUCKDB_API DatabaseManager &GetDatabaseManager(); DUCKDB_API FileSystem &GetFileSystem(); DUCKDB_API ExternalFileCache &GetExternalFileCache(); + DUCKDB_API ResultSetManager &GetResultSetManager(); DUCKDB_API TaskScheduler &GetScheduler(); DUCKDB_API ObjectCache &GetObjectCache(); DUCKDB_API ConnectionManager &GetConnectionManager(); @@ -69,7 +72,7 @@ class DatabaseInstance : public enable_shared_from_this { DUCKDB_API SettingLookupResult TryGetCurrentSetting(const string &key, Value &result) const; - DUCKDB_API shared_ptr GetEncryptionUtil() const; + DUCKDB_API shared_ptr GetEncryptionUtil(); shared_ptr CreateAttachedDatabase(ClientContext &context, AttachInfo &info, AttachOptions &options); @@ -90,8 +93,9 
@@ class DatabaseInstance : public enable_shared_from_this { unique_ptr extension_manager; ValidChecker db_validity; unique_ptr db_file_system; - shared_ptr log_manager; + unique_ptr log_manager; unique_ptr external_file_cache; + unique_ptr result_set_manager; duckdb_ext_api_v1 (*create_api_v1)(); }; diff --git a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp index 1912a90bf..3af2f1873 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp @@ -12,33 +12,42 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" +#include "duckdb/common/enums/access_mode.hpp" +#include "duckdb/common/reference_map.hpp" namespace duckdb { struct AttachInfo; struct AttachOptions; +class DatabaseManager; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - explicit DatabasePathInfo(string name_p) : name(std::move(name_p)) { - } + DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode); string name; + AccessMode access_mode; + reference_set_t attached_databases; + idx_t reference_count = 1; }; //! The DatabaseFilePathManager is used to ensure we only ever open a single database file once class DatabaseFilePathManager { public: idx_t ApproxDatabaseCount() const; - InsertDatabasePathResult InsertDatabasePath(const string &path, const string &name, OnCreateConflict on_conflict, - AttachOptions &options); + InsertDatabasePathResult InsertDatabasePath(DatabaseManager &manager, const string &path, const string &name, + OnCreateConflict on_conflict, AttachOptions &options); //! Erase a database path - indicating we are done with using it void EraseDatabasePath(const string &path); + //! Called when a database is detached, but before it is fully finished being used + void DetachDatabase(DatabaseManager &manager, const string &path); private: //! The lock to add entries to the db_paths map mutable mutex db_paths_lock; - //! A set containing all attached database paths mapped to their attached database name + //! A set containing all attached database paths + //! This allows us to attach many databases efficiently, and to avoid attaching the + //! same file path twice case_insensitive_map_t db_paths; }; diff --git a/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp b/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp index 8fa226b2d..a8a14c416 100644 --- a/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp +++ b/src/duckdb/src/include/duckdb/main/db_instance_cache.hpp @@ -26,6 +26,8 @@ struct DatabaseCacheEntry { mutex update_database_mutex; }; +enum class CacheBehavior { AUTOMATIC, ALWAYS_CACHE, NEVER_CACHE }; + class DBInstanceCache { public: DBInstanceCache(); @@ -41,6 +43,9 @@ class DBInstanceCache { //!
Either returns an existing entry, or creates and caches a new DB Instance shared_ptr GetOrCreateInstance(const string &database, DBConfig &config_dict, bool cache_instance, const std::function &on_create = nullptr); + shared_ptr GetOrCreateInstance(const string &database, DBConfig &config_dict, + CacheBehavior cache_behavior = CacheBehavior::AUTOMATIC, + const std::function &on_create = nullptr); private: shared_ptr path_manager; diff --git a/src/duckdb/src/include/duckdb/main/error_manager.hpp b/src/duckdb/src/include/duckdb/main/error_manager.hpp index aaedffd4b..065f6399a 100644 --- a/src/duckdb/src/include/duckdb/main/error_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/error_manager.hpp @@ -34,38 +34,39 @@ enum class ErrorType : uint16_t { class ErrorManager { public: template - string FormatException(ErrorType error_type, ARGS... params) { + string FormatException(ErrorType error_type, ARGS &&...params) { vector values; - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } DUCKDB_API string FormatExceptionRecursive(ErrorType error_type, vector &values); template string FormatExceptionRecursive(ErrorType error_type, vector &values, T param, - ARGS... params) { + ARGS &&...params) { values.push_back(ExceptionFormatValue::CreateFormatValue(param)); - return FormatExceptionRecursive(error_type, values, params...); + return FormatExceptionRecursive(error_type, values, std::forward(params)...); } template - static string FormatException(ClientContext &context, ErrorType error_type, ARGS... params) { - return Get(context).FormatException(error_type, params...); + static string FormatException(ClientContext &context, ErrorType error_type, ARGS &&...params) { + return Get(context).FormatException(error_type, std::forward(params)...); } DUCKDB_API static InvalidInputException InvalidUnicodeError(const String &input, const string &context); DUCKDB_API static FatalException InvalidatedDatabase(ClientContext &context, const string &invalidated_msg); + DUCKDB_API static TransactionException InvalidatedTransaction(ClientContext &context); //! Adds a custom error for a specific error type void AddCustomError(ErrorType type, string new_error); DUCKDB_API static ErrorManager &Get(ClientContext &context); + DUCKDB_API static ErrorManager &Get(DatabaseInstance &context); private: map custom_errors; }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp b/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp index 70bbf455f..b5e94e7eb 100644 --- a/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp +++ b/src/duckdb/src/include/duckdb/main/extension/generated_extension_loader.hpp @@ -18,9 +18,6 @@ namespace duckdb { -//! 
Looks through the CMake-generated list of extensions that are linked into DuckDB currently to try load -bool TryLoadLinkedExtension(DuckDB &db, const string &extension); - vector LinkedExtensions(); vector LoadedExtensionTestPaths(); diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index a32331c9b..cb95742f5 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -42,7 +42,6 @@ struct ExtensionFunctionOverloadEntry { static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"!__postfix", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"&", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, - {"&&", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"**", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"->>", "json", CatalogType::SCALAR_FUNCTION_ENTRY}, {"<->", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -69,8 +68,10 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"approx_top_k", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_max_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_max_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"arg_min_null", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"arg_min_nulls_last", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmax", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"argmin", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"array_agg", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, @@ -475,6 +476,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"ord", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"parquet_bloom_probe", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_file_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, + {"parquet_full_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_kv_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_metadata", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, {"parquet_scan", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -547,6 +549,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"sin", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"sinh", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"skewness", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"sleep_ms", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"sql_auto_complete", "autocomplete", CatalogType::TABLE_FUNCTION_ENTRY}, {"sqlite_attach", "sqlite_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"sqlite_query", "sqlite_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -599,6 +602,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"st_envelope", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_envelope_agg", "spatial", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"st_equals", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"st_expand", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_extent", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, {"st_extent_agg", "spatial", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"st_extent_approx", "spatial", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -721,7 +725,9 @@ 
static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"string_agg", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"strpos", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"struct_insert", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"struct_keys", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"struct_update", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"struct_values", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"sum", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"sum_no_overflow", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"sumkahan", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, @@ -779,6 +785,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"var_pop", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"var_samp", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"variance", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"variant_to_parquet_variant", "parquet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vector_type", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"version", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vss_join", "vss", CatalogType::TABLE_MACRO_ENTRY}, @@ -1080,7 +1087,6 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"ui_remote_url", "ui"}, {"unsafe_disable_etag_checks", "httpfs"}, {"unsafe_enable_version_guessing", "iceberg"}, - {"variant_legacy_encoding", "parquet"}, }; // END_OF_EXTENSION_SETTINGS static constexpr ExtensionEntry EXTENSION_SECRET_TYPES[] = { diff --git a/src/duckdb/src/include/duckdb/main/extension_helper.hpp b/src/duckdb/src/include/duckdb/main/extension_helper.hpp index e1037acea..480bc398f 100644 --- a/src/duckdb/src/include/duckdb/main/extension_helper.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_helper.hpp @@ -93,7 +93,7 @@ struct ExtensionInstallOptions { class ExtensionHelper { public: static void LoadAllExtensions(DuckDB &db); - + static vector LoadedExtensionTestPaths(); static ExtensionLoadResult LoadExtension(DuckDB &db, const std::string &extension); //! 
Install an extension @@ -122,9 +122,9 @@ class ExtensionHelper { static string ExtensionDirectory(ClientContext &context); static string ExtensionDirectory(DatabaseInstance &db, FileSystem &fs); - // Get the extension directory path - static string GetExtensionDirectoryPath(ClientContext &context); - static string GetExtensionDirectoryPath(DatabaseInstance &db, FileSystem &fs); + // Get all extension directory paths + static vector GetExtensionDirectoryPath(ClientContext &context); + static vector GetExtensionDirectoryPath(DatabaseInstance &db, FileSystem &fs); static bool CheckExtensionSignature(FileHandle &handle, ParsedExtensionMetaData &parsed_metadata, const bool allow_community_extensions); @@ -243,7 +243,7 @@ class ExtensionHelper { ExtensionInstallOptions &options, optional_ptr context = nullptr); static const vector PathComponents(); - static string DefaultExtensionFolder(FileSystem &fs); + static vector DefaultExtensionFolders(FileSystem &fs); static bool AllowAutoInstall(const string &extension); static ExtensionInitResult InitialLoad(DatabaseInstance &db, FileSystem &fs, const string &extension); static bool TryInitialLoad(DatabaseInstance &db, FileSystem &fs, const string &extension, diff --git a/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp index 25c50d980..44c3cd67b 100644 --- a/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/materialized_query_result.hpp @@ -30,10 +30,6 @@ class MaterializedQueryResult : public QueryResult { DUCKDB_API explicit MaterializedQueryResult(ErrorData error); public: - //! Fetches a DataChunk from the query result. - //! This will consume the result (i.e. the result can only be scanned once with this function) - DUCKDB_API unique_ptr Fetch() override; - DUCKDB_API unique_ptr FetchRaw() override; //! Converts the QueryResult to a string DUCKDB_API string ToString() override; DUCKDB_API string ToBox(ClientContext &context, const BoxRendererConfig &config) override; @@ -56,6 +52,9 @@ class MaterializedQueryResult : public QueryResult { //! Takes ownership of the collection, 'collection' is null after this operation unique_ptr TakeCollection(); +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: unique_ptr collection; //! Row collection, only created if GetValue is called diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp index b40a9addb..84f50d654 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp @@ -45,7 +45,9 @@ class PreparedStatementData { //! The map of parameter index to the actual value entry bound_parameter_map_t value_map; //! Whether we are creating a streaming result or not - bool is_streaming = false; + QueryResultOutputType output_type; + //! Whether we are creating a buffer-managed result or not + QueryResultMemoryType memory_type; public: void CheckParameterCount(idx_t parameter_count); diff --git a/src/duckdb/src/include/duckdb/main/profiling_info.hpp b/src/duckdb/src/include/duckdb/main/profiling_info.hpp index 904f0205d..a3f160957 100644 --- a/src/duckdb/src/include/duckdb/main/profiling_info.hpp +++ b/src/duckdb/src/include/duckdb/main/profiling_info.hpp @@ -32,9 +32,6 @@ class ProfilingInfo { profiler_settings_t expanded_settings; //! Contains all enabled metrics. 
profiler_metrics_t metrics; - //! Additional metrics. - // FIXME: move to metrics. - InsertionOrderPreservingMap extra_info; public: ProfilingInfo() = default; @@ -42,31 +39,27 @@ class ProfilingInfo { ProfilingInfo(ProfilingInfo &) = default; ProfilingInfo &operator=(ProfilingInfo const &) = default; -public: - static profiler_settings_t DefaultSettings(); - static profiler_settings_t DefaultRootSettings(); - static profiler_settings_t DefaultOperatorSettings(); - public: void ResetMetrics(); //! Returns true, if the query profiler must collect this metric. - static bool Enabled(const profiler_settings_t &settings, const MetricsType metric); + static bool Enabled(const profiler_settings_t &settings, const MetricType metric); //! Expand metrics depending on the collection of other metrics. - static void Expand(profiler_settings_t &settings, const MetricsType metric); + static void Expand(profiler_settings_t &settings, const MetricType metric); public: - string GetMetricAsString(const MetricsType metric) const; + string GetMetricAsString(const MetricType metric) const; + void WriteMetricsToLog(ClientContext &context); void WriteMetricsToJSON(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *destination); public: template - METRIC_TYPE GetMetricValue(const MetricsType type) const { + METRIC_TYPE GetMetricValue(const MetricType type) const { auto val = metrics.at(type); return val.GetValue(); } template - void MetricUpdate(const MetricsType type, const Value &value, + void MetricUpdate(const MetricType type, const Value &value, const std::function &update_fun) { if (metrics.find(type) == metrics.end()) { metrics[type] = value; @@ -77,36 +70,52 @@ class ProfilingInfo { } template - void MetricUpdate(const MetricsType type, const METRIC_TYPE &value, + void MetricUpdate(const MetricType type, const METRIC_TYPE &value, const std::function &update_fun) { auto new_value = Value::CreateValue(value); MetricUpdate(type, new_value, update_fun); } template - void MetricSum(const MetricsType type, const Value &value) { + void MetricSum(const MetricType type, const Value &value) { MetricUpdate(type, value, [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { return old_value + new_value; }); } template - void MetricSum(const MetricsType type, const METRIC_TYPE &value) { + void MetricSum(const MetricType type, const METRIC_TYPE &value) { auto new_value = Value::CreateValue(value); return MetricSum(type, new_value); } template - void MetricMax(const MetricsType type, const Value &value) { + void MetricMax(const MetricType type, const Value &value) { MetricUpdate(type, value, [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { return MaxValue(old_value, new_value); }); } + template - void MetricMax(const MetricsType type, const METRIC_TYPE &value) { + void MetricMax(const MetricType type, const METRIC_TYPE &value) { auto new_value = Value::CreateValue(value); return MetricMax(type, new_value); } }; +// Specialization for InsertionOrderPreservingMap +template <> +inline InsertionOrderPreservingMap +ProfilingInfo::GetMetricValue>(const MetricType type) const { + auto val = metrics.at(type); + InsertionOrderPreservingMap result; + auto children = MapValue::GetChildren(val); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + result.insert(key, value); + } + return result; +} } // namespace duckdb diff --git 
a/src/duckdb/src/include/duckdb/main/profiling_utils.hpp b/src/duckdb/src/include/duckdb/main/profiling_utils.hpp new file mode 100644 index 000000000..037619309 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/profiling_utils.hpp @@ -0,0 +1,138 @@ +//===----------------------------------------------------------------------===// +// +// DuckDB +// +// duckdb/main/profiling_utils.hpp +// +// This file is automatically generated by scripts/generate_metric_enums.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/metric_type.hpp" +#include "duckdb/main/profiling_node.hpp" +#include "duckdb/main/profiling_info.hpp" +#include "duckdb/common/profiler.hpp" + +namespace duckdb_yyjson { +struct yyjson_mut_doc; +struct yyjson_mut_val; +} // namespace duckdb_yyjson + +namespace duckdb { + +struct ActiveTimer; + +// Top level query metrics +struct QueryMetrics { +public: + QueryMetrics() { + Reset(); + } + + ProfilingInfo query_global_info; + + std::string query_name; + unique_ptr latency_timer; + +public: + void UpdateMetric(const MetricType metric, idx_t addition) { + active_metrics[GetMetricsIndex(metric)] += addition; + } + + idx_t GetMetricValue(const MetricType metric) const { + return active_metrics[GetMetricsIndex(metric)]; + } + + double GetMetricInSeconds(const MetricType metric) const { + return static_cast(active_metrics[GetMetricsIndex(metric)]) / 1e9; + } + + void Reset() { + for(idx_t i = 0; i < ACTIVELY_TRACKED_METRICS; i++) { + active_metrics[i] = 0; + } + } + + void Merge(const QueryMetrics &other) { + for(idx_t i = 0; i < ACTIVELY_TRACKED_METRICS; i++) { + active_metrics[i] += other.active_metrics[i]; + } + } + + static idx_t GetMetricsIndex(MetricType type) { + switch(type) { + case MetricType::ATTACH_LOAD_STORAGE_LATENCY: return 0; + case MetricType::ATTACH_REPLAY_WAL_LATENCY: return 1; + case MetricType::CHECKPOINT_LATENCY: return 2; + case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: return 3; + case MetricType::LATENCY: return 4; + case MetricType::WAITING_TO_ATTACH_LATENCY: return 5; + case MetricType::WRITE_TO_WAL_LATENCY: return 6; + case MetricType::TOTAL_BYTES_READ: return 7; + case MetricType::TOTAL_BYTES_WRITTEN: return 8; + case MetricType::TOTAL_MEMORY_ALLOCATED: return 9; + case MetricType::WAL_REPLAY_ENTRY_COUNT: return 10; + default: + throw InternalException("MetricType %s is not actively tracked.", EnumUtil::ToString(type)); + } + } + +private: + static constexpr const idx_t ACTIVELY_TRACKED_METRICS = 11; + + atomic active_metrics[ACTIVELY_TRACKED_METRICS]; +}; + +class ProfilingUtils { +public: + static void SetMetricToDefault(profiler_metrics_t &metrics, const MetricType &type); + static void MetricToJson(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *dest, const char *key_ptr, profiler_metrics_t &metrics, const MetricType &type); + static void CollectMetrics(const MetricType &type, QueryMetrics &query_metrics, Value &metric, ProfilingNode &node, ProfilingInfo &child_info); +}; + +struct ActiveTimer { +public: + ActiveTimer(QueryMetrics &query_metrics, const MetricType metric, const bool is_active = true) : query_metrics(query_metrics), metric(metric), is_active(is_active) { + // start on constructor + if (!is_active) { + return; + } + profiler.Start(); + } + + ~ActiveTimer() { + if (is_active) { + // automatically end in destructor + EndTimer(); + } + } + + // Automatically called in the 
destructor. + void EndTimer() { + if (!is_active) { + return; + } + // stop profiling and report + is_active = false; + profiler.End(); + query_metrics.UpdateMetric(metric, profiler.ElapsedNanos()); + } + + void Reset() { + if (!is_active) { + return; + } + profiler.Reset(); + is_active = false; + } + +private: + QueryMetrics &query_metrics; + const MetricType metric; + Profiler profiler; + bool is_active; +}; + +} diff --git a/src/duckdb/src/include/duckdb/main/query_parameters.hpp b/src/duckdb/src/include/duckdb/main/query_parameters.hpp new file mode 100644 index 000000000..d9bb42a3b --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/query_parameters.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/query_parameters.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +enum class QueryResultOutputType : uint8_t { FORCE_MATERIALIZED, ALLOW_STREAMING }; + +enum class QueryResultMemoryType : uint8_t { IN_MEMORY, BUFFER_MANAGED }; + +struct QueryParameters { + QueryParameters() { + } + QueryParameters(bool allow_streaming) // NOLINT: allow implicit conversion + : output_type(allow_streaming ? QueryResultOutputType::ALLOW_STREAMING + : QueryResultOutputType::FORCE_MATERIALIZED) { + } + QueryParameters(QueryResultOutputType output_type) // NOLINT: allow implicit conversion + : output_type(output_type) { + } + QueryResultOutputType output_type = QueryResultOutputType::FORCE_MATERIALIZED; + QueryResultMemoryType memory_type = QueryResultMemoryType::IN_MEMORY; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_profiler.hpp b/src/duckdb/src/include/duckdb/main/query_profiler.hpp index 0f7b8812d..e68841e8b 100644 --- a/src/duckdb/src/include/duckdb/main/query_profiler.hpp +++ b/src/duckdb/src/include/duckdb/main/query_profiler.hpp @@ -21,10 +21,8 @@ #include "duckdb/common/winapi.hpp" #include "duckdb/execution/expression_executor_state.hpp" #include "duckdb/execution/physical_operator.hpp" -#include "duckdb/main/profiling_info.hpp" #include "duckdb/main/profiling_node.hpp" - -#include +#include "duckdb/main/profiling_utils.hpp" namespace duckdb { @@ -33,6 +31,7 @@ class ExpressionExecutor; class ProfilingNode; class PhysicalOperator; class SQLStatement; +struct ActiveTimer; enum class ProfilingCoverage : uint8_t { SELECT = 0, ALL = 1 }; @@ -94,7 +93,6 @@ class OperatorProfiler { DUCKDB_API void Flush(const PhysicalOperator &phys_op); DUCKDB_API OperatorInformation &GetOperatorInfo(const PhysicalOperator &phys_op); DUCKDB_API bool OperatorInfoIsInitialized(const PhysicalOperator &phys_op); - DUCKDB_API void AddExtraInfo(InsertionOrderPreservingMap extra_info); public: ClientContext &context; @@ -113,22 +111,6 @@ class OperatorProfiler { reference_map_t operator_infos; }; -//! Top level query metrics. -struct QueryMetrics { - QueryMetrics() : total_bytes_read(0), total_bytes_written(0) {}; - - ProfilingInfo query_global_info; - - //! The SQL string of the query - string query; - //! The timer used to time the excution time of the entire query - Profiler latency; - //! The total bytes read by the file system - atomic total_bytes_read; - //! The total bytes written by the file system - atomic total_bytes_written; -}; - //! QueryProfiler collects the profiling metrics of a query. 
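// A minimal usage sketch (not taken from the diff above), assuming only the QueryMetrics and
// ActiveTimer declarations from the new profiling_utils.hpp: the timer starts in its constructor
// and, on destruction, adds the elapsed nanoseconds to the chosen metric. The function name and
// the work being timed are hypothetical.
#include "duckdb/main/profiling_utils.hpp"

static void TimeCheckpointWork(duckdb::QueryMetrics &metrics) {
	using namespace duckdb;
	{
		// starts the internal Profiler immediately
		ActiveTimer timer(metrics, MetricType::CHECKPOINT_LATENCY);
		// ... perform the checkpoint work being measured ...
	} // the destructor calls EndTimer(), adding the elapsed nanoseconds to the metric
	// read the accumulated value back, converted from nanoseconds to seconds
	double seconds = metrics.GetMetricInSeconds(MetricType::CHECKPOINT_LATENCY);
	(void)seconds;
}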
class QueryProfiler { public: @@ -138,9 +120,6 @@ class QueryProfiler { DUCKDB_API explicit QueryProfiler(ClientContext &context); public: - //! Propagate save_location, enabled, detailed_enabled and automatic_print_format. - void Propagate(QueryProfiler &qp); - DUCKDB_API bool IsEnabled() const; DUCKDB_API bool IsDetailedEnabled() const; DUCKDB_API ProfilerPrintFormat GetPrintFormat(ExplainFormat format = ExplainFormat::DEFAULT) const; @@ -154,19 +133,20 @@ class QueryProfiler { DUCKDB_API void StartQuery(const string &query, bool is_explain_analyze = false, bool start_at_optimizer = false); DUCKDB_API void EndQuery(); - //! Adds nr_bytes bytes to the total bytes read. - DUCKDB_API void AddBytesRead(const idx_t nr_bytes); - //! Adds nr_bytes bytes to the total bytes written. - DUCKDB_API void AddBytesWritten(const idx_t nr_bytes); + //! Adds amount to a specific metric type. + DUCKDB_API void AddToCounter(MetricType type, const idx_t amount); + + //! Start/End a timer for a specific metric type. + DUCKDB_API ActiveTimer StartTimer(MetricType type); DUCKDB_API void StartExplainAnalyze(); //! Adds the timings gathered by an OperatorProfiler to this query profiler DUCKDB_API void Flush(OperatorProfiler &profiler); //! Adds the top level query information to the global profiler. - DUCKDB_API void SetInfo(const double &blocked_thread_time); + DUCKDB_API void SetBlockedTime(const double &blocked_thread_time); - DUCKDB_API void StartPhase(MetricsType phase_metric); + DUCKDB_API void StartPhase(MetricType phase_metric); DUCKDB_API void EndPhase(); DUCKDB_API void Initialize(const PhysicalOperator &root); @@ -180,11 +160,15 @@ class QueryProfiler { DUCKDB_API string ToString(ExplainFormat format = ExplainFormat::DEFAULT) const; DUCKDB_API string ToString(ProfilerPrintFormat format) const; - static InsertionOrderPreservingMap JSONSanitize(const InsertionOrderPreservingMap &input); + // Sanitize a Value::MAP + static Value JSONSanitize(const Value &input); static string JSONSanitize(const string &text); static string DrawPadded(const string &str, idx_t width); + DUCKDB_API void ToLog() const; DUCKDB_API string ToJSON() const; DUCKDB_API void WriteToFile(const char *path, string &info) const; + DUCKDB_API idx_t GetBytesRead() const; + DUCKDB_API idx_t GetBytesWritten() const; idx_t OperatorSize() { return tree_map.size(); @@ -241,11 +225,11 @@ class QueryProfiler { //! The timer used to time the individual phases of the planning process Profiler phase_profiler; //! A mapping of the phase names to the timings - using PhaseTimingStorage = unordered_map; + using PhaseTimingStorage = unordered_map; PhaseTimingStorage phase_timings; using PhaseTimingItem = PhaseTimingStorage::value_type; //! 
The stack of currently active phases - vector phase_stack; + vector phase_stack; private: void MoveOptimizerPhasesToRoot(); diff --git a/src/duckdb/src/include/duckdb/main/query_result.hpp b/src/duckdb/src/include/duckdb/main/query_result.hpp index f85629428..68ce7c8ac 100644 --- a/src/duckdb/src/include/duckdb/main/query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/query_result.hpp @@ -44,7 +44,7 @@ class BaseQueryResult { DUCKDB_API void SetError(ErrorData error); DUCKDB_API bool HasError() const; DUCKDB_API const ExceptionType &GetErrorType() const; - DUCKDB_API const std::string &GetError(); + DUCKDB_API const std::string &GetError() const; DUCKDB_API ErrorData &GetErrorObject(); DUCKDB_API idx_t ColumnCount(); @@ -98,10 +98,10 @@ class QueryResult : public BaseQueryResult { DUCKDB_API const string &ColumnName(idx_t index) const; //! Fetches a DataChunk of normalized (flat) vectors from the query result. //! Returns nullptr if there are no more results to fetch. - DUCKDB_API virtual unique_ptr Fetch(); + DUCKDB_API unique_ptr Fetch(); //! Fetches a DataChunk from the query result. The vectors are not normalized and hence any vector types can be //! returned. - DUCKDB_API virtual unique_ptr FetchRaw() = 0; + DUCKDB_API unique_ptr FetchRaw(); //! Converts the QueryResult to a string DUCKDB_API virtual string ToString() = 0; //! Converts the QueryResult to a box-rendered string @@ -125,57 +125,70 @@ class QueryResult : public BaseQueryResult { } } +protected: + DUCKDB_API virtual unique_ptr FetchInternal() = 0; + private: class QueryResultIterator; class QueryResultRow { + friend class QueryResultIterator; + public: - explicit QueryResultRow(QueryResultIterator &iterator_p, idx_t row_idx) : iterator(iterator_p), row(0) { + explicit QueryResultRow() : row(0) { } - QueryResultIterator &iterator; - idx_t row; - bool IsNull(idx_t col_idx) const { - return iterator.chunk->GetValue(col_idx, row).IsNull(); + return chunk->GetValue(col_idx, row).IsNull(); } template T GetValue(idx_t col_idx) const { - return iterator.chunk->GetValue(col_idx, row).GetValue(); + return chunk->GetValue(col_idx, row).GetValue(); + } + Value GetBaseValue(idx_t col_idx) const { + return chunk->GetValue(col_idx, row); + } + DataChunk &GetChunk() const { + return *chunk; } + idx_t GetRowInChunk() const { + return row; + } + + private: + shared_ptr chunk; + idx_t row; }; //! The row-based query result iterator. 
Invoking the class QueryResultIterator { public: - explicit QueryResultIterator(optional_ptr result_p) - : current_row(*this, 0), result(result_p), base_row(0) { + explicit QueryResultIterator(optional_ptr result_p = nullptr) : result(result_p), base_row(0) { if (result) { - chunk = shared_ptr(result->Fetch().release()); - if (!chunk) { + current_row.chunk = shared_ptr(result->Fetch().release()); + if (!current_row.chunk) { result = nullptr; } } } QueryResultRow current_row; - shared_ptr chunk; optional_ptr result; idx_t base_row; public: void Next() { - if (!chunk) { + if (!current_row.chunk) { return; } current_row.row++; - if (current_row.row >= chunk->size()) { - base_row += chunk->size(); - chunk = shared_ptr(result->Fetch().release()); + if (current_row.row >= current_row.chunk->size()) { + base_row += current_row.chunk->size(); + current_row.chunk = shared_ptr(result->Fetch().release()); current_row.row = 0; - if (!chunk || chunk->size() == 0) { + if (!current_row.chunk || current_row.chunk->size() == 0) { // exhausted all rows base_row = 0; result = nullptr; - chunk.reset(); + current_row.chunk.reset(); } } } @@ -187,16 +200,21 @@ class QueryResult : public BaseQueryResult { bool operator!=(const QueryResultIterator &other) const { return result != other.result || base_row != other.base_row || current_row.row != other.current_row.row; } + bool operator==(const QueryResultIterator &other) const { + return !(*this != other); + } const QueryResultRow &operator*() const { return current_row; } }; public: - QueryResultIterator begin() { // NOLINT: match stl API + using iterator = QueryResultIterator; + + iterator begin() { // NOLINT: match stl API return QueryResultIterator(this); } - QueryResultIterator end() { // NOLINT: match stl API + iterator end() { // NOLINT: match stl API return QueryResultIterator(nullptr); } diff --git a/src/duckdb/src/include/duckdb/main/relation.hpp b/src/duckdb/src/include/duckdb/main/relation.hpp index 9d9e67686..94450b0be 100644 --- a/src/duckdb/src/include/duckdb/main/relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation.hpp @@ -78,7 +78,8 @@ class Relation : public enable_shared_from_this { public: DUCKDB_API virtual const vector &Columns() = 0; - DUCKDB_API virtual unique_ptr GetQueryNode(); + DUCKDB_API virtual unique_ptr GetQueryNode() = 0; + DUCKDB_API virtual string GetQuery(); DUCKDB_API virtual BoundStatement Bind(Binder &binder); DUCKDB_API virtual string GetAlias(); @@ -161,19 +162,27 @@ class Relation : public enable_shared_from_this { //! Insert the data from this relation into a table DUCKDB_API shared_ptr InsertRel(const string &schema_name, const string &table_name); + DUCKDB_API shared_ptr InsertRel(const string &catalog_name, const string &schema_name, + const string &table_name); DUCKDB_API void Insert(const string &table_name); DUCKDB_API void Insert(const string &schema_name, const string &table_name); + DUCKDB_API void Insert(const string &catalog_name, const string &schema_name, const string &table_name); //! Insert a row (i.e.,list of values) into a table - DUCKDB_API void Insert(const vector> &values); - DUCKDB_API void Insert(vector>> &&expressions); + DUCKDB_API virtual void Insert(const vector> &values); + DUCKDB_API virtual void Insert(vector>> &&expressions); //! 
Create a table and insert the data from this relation into that table DUCKDB_API shared_ptr CreateRel(const string &schema_name, const string &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); + DUCKDB_API shared_ptr CreateRel(const string &catalog_name, const string &schema_name, + const string &table_name, bool temporary = false, + OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); DUCKDB_API void Create(const string &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); DUCKDB_API void Create(const string &schema_name, const string &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); + DUCKDB_API void Create(const string &catalog_name, const string &schema_name, const string &table_name, + bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); //! Write a relation to a CSV file DUCKDB_API shared_ptr diff --git a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp index 7d5462941..cfc0e243a 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp @@ -16,8 +16,11 @@ class CreateTableRelation : public Relation { public: CreateTableRelation(shared_ptr child, string schema_name, string table_name, bool temporary, OnCreateConflict on_conflict); + CreateTableRelation(shared_ptr child, string catalog_name, string schema_name, string table_name, + bool temporary, OnCreateConflict on_conflict); shared_ptr child; + string catalog_name; string schema_name; string table_name; vector columns; @@ -26,6 +29,8 @@ class CreateTableRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp index cb826a86c..aa09b0def 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp @@ -26,6 +26,8 @@ class CreateViewRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp index c07445ba4..0c25c6576 100644 --- a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp @@ -26,6 +26,8 @@ class DeleteRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp index 888583b2b..96be08d8f 100644 --- 
a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp @@ -24,6 +24,8 @@ class ExplainRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp index 3695cde7b..41756488f 100644 --- a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp @@ -15,14 +15,18 @@ namespace duckdb { class InsertRelation : public Relation { public: InsertRelation(shared_ptr child, string schema_name, string table_name); + InsertRelation(shared_ptr child, string catalog_name, string schema_name, string table_name); shared_ptr child; + string catalog_name; string schema_name; string table_name; vector columns; public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp index bdb035652..b1be001b9 100644 --- a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp @@ -28,6 +28,7 @@ class QueryRelation : public Relation { public: static unique_ptr ParseStatement(ClientContext &context, const string &query, const string &error); unique_ptr GetQueryNode() override; + string GetQuery() override; unique_ptr GetTableRef() override; BoundStatement Bind(Binder &binder) override; diff --git a/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp index 9c2fddcec..a3184dafa 100644 --- a/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp @@ -29,6 +29,8 @@ class TableRelation : public Relation { unique_ptr GetTableRef() override; + void Insert(const vector> &values) override; + void Insert(vector>> &&expressions) override; void Update(const string &update, const string &condition = string()) override; void Update(vector column_names, vector> &&update, unique_ptr condition = nullptr) override; diff --git a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp index 58ad203b2..91eac246e 100644 --- a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp @@ -29,6 +29,8 @@ class UpdateRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp index 99d2ebe8e..cf0853ff3 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp @@ -23,6 
+23,8 @@ class WriteCSVRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp index d32089212..138eee7c7 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp @@ -24,6 +24,8 @@ class WriteParquetRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/result_set_manager.hpp b/src/duckdb/src/include/duckdb/main/result_set_manager.hpp new file mode 100644 index 000000000..0be2a4b88 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/result_set_manager.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/result_set_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/shared_ptr.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +class DatabaseInstance; +class ClientContext; +class BlockHandle; +class ColumnDataAllocator; + +class ManagedResultSet : public enable_shared_from_this { +public: + ManagedResultSet(); + ManagedResultSet(const weak_ptr &db, vector> &handles); + +public: + bool IsValid() const; + shared_ptr GetDatabase() const; + vector> &GetHandles(); + +private: + bool valid; + weak_ptr db; + optional_ptr>> handles; +}; + +class ResultSetManager { +public: + explicit ResultSetManager(DatabaseInstance &db); + +public: + static ResultSetManager &Get(ClientContext &context); + static ResultSetManager &Get(DatabaseInstance &db); + ManagedResultSet Add(ColumnDataAllocator &allocator); + void Remove(ColumnDataAllocator &allocator); + +private: + mutex lock; + weak_ptr db; + reference_map_t>>> open_results; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/secret/secret.hpp b/src/duckdb/src/include/duckdb/main/secret/secret.hpp index ed8034413..fd8a1b241 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret.hpp @@ -296,7 +296,9 @@ class KeyValueSecretReader { Value result; auto lookup_result = TryGetSecretKeyOrSetting(secret_key, setting_name, result); if (lookup_result) { - value_out = result.GetValue(); + if (!result.IsNull()) { + value_out = result.GetValue(); + } } return lookup_result; } diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 383d5533b..b3b7e10cd 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -95,6 +95,18 @@ struct AllowExtensionsMetadataMismatchSetting { static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; +struct AllowParserOverrideExtensionSetting { + using RETURN_TYPE = string; + static constexpr const char *Name = "allow_parser_override_extension"; + static constexpr 
const char *Description = "Allow extensions to override the current parser"; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static bool OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input); + static bool OnGlobalReset(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct AllowPersistentSecretsSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "allow_persistent_secrets"; @@ -237,6 +249,17 @@ struct AutoloadKnownExtensionsSetting { static Value GetSetting(const ClientContext &context); }; +struct BlockAllocatorMemorySetting { + using RETURN_TYPE = string; + static constexpr const char *Name = "block_allocator_memory"; + static constexpr const char *Description = "Physical memory that the block allocator is allowed to use (this " + "memory is never freed and cannot be reduced)."; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct CatalogErrorMaxSchemasSetting { using RETURN_TYPE = idx_t; static constexpr const char *Name = "catalog_error_max_schemas"; @@ -308,6 +331,15 @@ struct DebugCheckpointAbortSetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; +struct DebugCheckpointSleepMsSetting { + using RETURN_TYPE = idx_t; + static constexpr const char *Name = "debug_checkpoint_sleep_ms"; + static constexpr const char *Description = "DEBUG SETTING: time to sleep before a checkpoint"; + static constexpr const char *InputType = "UBIGINT"; + static constexpr const char *DefaultValue = "0"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; +}; + struct DebugForceExternalSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "debug_force_external"; @@ -329,6 +361,17 @@ struct DebugForceNoCrossProductSetting { static constexpr SetScope DefaultScope = SetScope::SESSION; }; +struct DebugPhysicalTableScanExecutionStrategySetting { + using RETURN_TYPE = PhysicalTableScanExecutionStrategy; + static constexpr const char *Name = "debug_physical_table_scan_execution_strategy"; + static constexpr const char *Description = + "DEBUG SETTING: force use of given strategy for executing physical table scans"; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = "DEFAULT"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct DebugSkipCheckpointOnCommitSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "debug_skip_checkpoint_on_commit"; @@ -338,6 +381,15 @@ struct DebugSkipCheckpointOnCommitSetting { static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; +struct DebugVerifyBlocksSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "debug_verify_blocks"; + static constexpr const char *Description = "DEBUG SETTING: verify block metadata during checkpointing"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; +}; + struct DebugVerifyVectorSetting { using RETURN_TYPE = DebugVectorVerification; static 
constexpr const char *Name = "debug_verify_vector"; @@ -645,7 +697,7 @@ struct ExperimentalMetadataReuseSetting { static constexpr const char *Name = "experimental_metadata_reuse"; static constexpr const char *Description = "EXPERIMENTAL: Re-use row group and table metadata when checkpointing."; static constexpr const char *InputType = "BOOLEAN"; - static constexpr const char *DefaultValue = "false"; + static constexpr const char *DefaultValue = "true"; static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; @@ -659,6 +711,16 @@ struct ExplainOutputSetting { static Value GetSetting(const ClientContext &context); }; +struct ExtensionDirectoriesSetting { + using RETURN_TYPE = vector; + static constexpr const char *Name = "extension_directories"; + static constexpr const char *Description = "Set the directories to store extensions in"; + static constexpr const char *InputType = "VARCHAR[]"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct ExtensionDirectorySetting { using RETURN_TYPE = string; static constexpr const char *Name = "extension_directory"; @@ -711,6 +773,17 @@ struct ForceCompressionSetting { static Value GetSetting(const ClientContext &context); }; +struct ForceVariantShredding { + using RETURN_TYPE = string; + static constexpr const char *Name = "force_variant_shredding"; + static constexpr const char *Description = + "Forces the VARIANT shredding that happens at checkpoint to use the provided schema for the shredding."; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct HomeDirectorySetting { using RETURN_TYPE = string; static constexpr const char *Name = "home_directory"; @@ -1253,6 +1326,17 @@ struct UsernameSetting { static Value GetSetting(const ClientContext &context); }; +struct VariantMinimumShreddingSize { + using RETURN_TYPE = int64_t; + static constexpr const char *Name = "variant_minimum_shredding_size"; + static constexpr const char *Description = "Minimum size of a rowgroup to enable VARIANT shredding, or set to -1 " + "to disable entirely. Defaults to 1/4th of a rowgroup"; + static constexpr const char *InputType = "BIGINT"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct WriteBufferRowGroupCountSetting { using RETURN_TYPE = idx_t; static constexpr const char *Name = "write_buffer_row_group_count"; diff --git a/src/duckdb/src/include/duckdb/main/stream_query_result.hpp b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp index 3c04a364c..775202ea7 100644 --- a/src/duckdb/src/include/duckdb/main/stream_query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/stream_query_result.hpp @@ -44,8 +44,6 @@ class StreamQueryResult : public QueryResult { DUCKDB_API void WaitForTask(); //! Executes a single task within the final pipeline, returning whether or not a chunk is ready to be fetched DUCKDB_API StreamExecutionResult ExecuteTask(); - //! Fetches a DataChunk from the query result. - DUCKDB_API unique_ptr FetchRaw() override; //! 
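Several new settings are declared in the settings.hpp hunks above (allow_parser_override_extension, block_allocator_memory, debug_checkpoint_sleep_ms, debug_verify_blocks, force_variant_shredding, variant_minimum_shredding_size, and so on). Assuming they are registered like the surrounding settings, they would be exercised through ordinary SET and current_setting() calls; a hedged sketch using the DuckDB C++ API, with setting names taken from the diff and values purely illustrative:

```cpp
#include "duckdb.hpp"

using namespace duckdb;

int main() {
	DuckDB db(nullptr);
	Connection con(db);
	// Illustrative values; whether a given build accepts these depends on the
	// settings actually being registered as declared in the header above.
	con.Query("SET debug_checkpoint_sleep_ms = 50");
	con.Query("SET debug_verify_blocks = true");
	con.Query("SET variant_minimum_shredding_size = -1"); // disable VARIANT shredding
	auto result = con.Query("SELECT current_setting('variant_minimum_shredding_size')");
	result->Print();
	return 0;
}
```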
Converts the QueryResult to a string DUCKDB_API string ToString() override; //! Materializes the query result and turns it into a materialized query result @@ -59,9 +57,12 @@ class StreamQueryResult : public QueryResult { //! The client context this StreamQueryResult belongs to shared_ptr context; +protected: + DUCKDB_API unique_ptr FetchInternal() override; + private: StreamExecutionResult ExecuteTaskInternal(ClientContextLock &lock); - unique_ptr FetchInternal(ClientContextLock &lock); + unique_ptr FetchNextInternal(ClientContextLock &lock); unique_ptr LockContext(); void CheckExecutableInternal(ClientContextLock &lock); bool IsOpenInternal(ClientContextLock &lock); diff --git a/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp new file mode 100644 index 000000000..8d8e35ea1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/common_subplan_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class Optimizer; +class LogicalOperator; + +//! The CommonSubplanOptimizer optimizer detects common subplans, and converts them to refs of a materialized CTE +class CommonSubplanOptimizer { +public: + explicit CommonSubplanOptimizer(Optimizer &optimizer); + +public: + unique_ptr Optimize(unique_ptr op); + +private: + //! The optimizer + Optimizer &optimizer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp index 90439b11e..97529f6ee 100644 --- a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp @@ -25,6 +25,7 @@ class CTEInlining { public: explicit CTEInlining(Optimizer &optimizer); unique_ptr Optimize(unique_ptr op); + static bool EndsInAggregateOrDistinct(const LogicalOperator &op); private: void TryInlining(unique_ptr &op); diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp index 890b90970..36cdbf59c 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp @@ -50,6 +50,7 @@ class FilterCombiner { //! If this returns true - this sorts "in_list" as a side-effect static bool IsDenseRange(vector &in_list); static bool ContainsNull(vector &in_list); + static bool FindNextLegalUTF8(string &prefix_string); void GenerateFilters(const std::function filter)> &callback); bool HasFilters(); diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp index a35fbaab9..b6cb1e704 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pullup.hpp @@ -30,7 +30,7 @@ class FilterPullup { // only pull up filters when there is a fork bool can_pullup = false; - // identifiy case the branch is a set operation (INTERSECT or EXCEPT) + // identify case the branch is a set operation (INTERSECT or EXCEPT) bool can_add_column = false; private: @@ -40,30 +40,26 @@ class FilterPullup { //! 
Pull up a LogicalFilter op unique_ptr PullupFilter(unique_ptr op); - //! Pull up filter in a LogicalProjection op unique_ptr PullupProjection(unique_ptr op); - //! Pull up filter in a LogicalCrossProduct op unique_ptr PullupCrossProduct(unique_ptr op); - + //! Pullup a filter in a LogicalJoin unique_ptr PullupJoin(unique_ptr op); - - // PPullup filter in a left join + //! Pullup filter in a left join unique_ptr PullupFromLeft(unique_ptr op); - - // Pullup filter in a inner join + //! Pullup filter in an inner join unique_ptr PullupInnerJoin(unique_ptr op); - - // Pullup filter in LogicalIntersect or LogicalExcept op + //! Pullup filter through a distinct + unique_ptr PullupDistinct(unique_ptr op); + //! Pullup filter in LogicalIntersect or LogicalExcept op unique_ptr PullupSetOperation(unique_ptr op); - + //! Pullup filter in both sides of a join unique_ptr PullupBothSide(unique_ptr op); - // Finish pull up at this operator + //! Finish pull up at this operator unique_ptr FinishPullup(unique_ptr op); - - // special treatment for SetOperations and projections + //! special treatment for SetOperations and projections void ProjectSetOperation(LogicalProjection &proj); }; // end FilterPullup diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp index c2fc87a52..29c2f0ac4 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -96,14 +96,17 @@ class FilterPushdown { unique_ptr FinishPushdown(unique_ptr op); //! Adds a filter to the set of filters. Returns FilterResult::UNSATISFIABLE if the subtree should be stripped, or //! FilterResult::SUCCESS otherwise + + unique_ptr PushFiltersIntoDelimJoin(unique_ptr op); FilterResult AddFilter(unique_ptr expr); //! Extract filter bindings to compare them with expressions in an operator and determine if the filter //! can be pushed down void ExtractFilterBindings(const Expression &expr, vector &bindings); //! Generate filters from the current set of filters stored in the FilterCombiner void GenerateFilters(); - //! if there are filters in this FilterPushdown node, push them into the combiner - void PushFilters(); + //! if there are filters in this FilterPushdown node, push them into the combiner. Returns + //! 
FilterResult::UNSATISFIABLE if the subtree should be stripped, or FilterResult::SUCCESS otherwise + FilterResult PushFilters(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_elimination.hpp b/src/duckdb/src/include/duckdb/optimizer/join_elimination.hpp new file mode 100644 index 000000000..6702f94e9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_elimination.hpp @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_elimination.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" + +namespace duckdb { +class JoinElimination; +class PipelineInfo; + +struct DistinctGroupRef { + column_binding_set_t distinct_group; + unordered_set ref_column_ids; +}; + +class PipelineInfo { +public: + unique_ptr root = nullptr; + unordered_set ref_table_ids; + unordered_map distinct_groups; + + // pushdown filter condition(ex in table scan operator), + // if have outer table columns then cannot elimination + bool has_filter = false; + + optional_ptr join_parent = nullptr; + idx_t join_index = 0; + +public: + PipelineInfo CreateChild() { + auto result = PipelineInfo(); + result.ref_table_ids = ref_table_ids; + result.distinct_groups = distinct_groups; + return result; + } +}; + +class JoinElimination : public LogicalOperatorVisitor { +public: + explicit JoinElimination() { + } + + void OptimizeChildren(LogicalOperator &op, optional_ptr parent, idx_t idx); + // with specific condition we can eliminate a (left/right, semi, inner) join. + // exemplify left/right join eliminaion condition: + // 1. output can only have outer table columns + // 2. join result cannot filter by inner table columns(ex. in where clause/ having clause ...) + // 3. must ensure each outer row can match at most one inner table row, such as: + // 1) inner table join condition is unique(ex. 1. join conditions have inner table's primary key 2. 
inner table + // join condition columns contains a whole distinct group) 2) join result columns contains a whole distinct group + unique_ptr Optimize(unique_ptr op); + void OptimizeInternal(unique_ptr op); + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + + unique_ptr CreateChildren() { + auto result = make_uniq(); + result->pipe_info = pipe_info.CreateChild(); + return result; + } + +private: + unique_ptr TryEliminateJoin(); + // void ExtractDistinctReferences(vector &expressions, idx_t target_table_index); + bool ContainDistinctGroup(vector &exprs); + + PipelineInfo pipe_info; + + optional_ptr children_root; + vector> children; + + unique_ptr left_child; + unique_ptr right_child; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_filter_pushdown_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/join_filter_pushdown_optimizer.hpp index aef9ba055..efb8c10f4 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_filter_pushdown_optimizer.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_filter_pushdown_optimizer.hpp @@ -27,6 +27,8 @@ class JoinFilterPushdownOptimizer : public LogicalOperatorVisitor { static void GetPushdownFilterTargets(LogicalOperator &op, vector columns, vector &targets); + static bool IsFiltering(const unique_ptr &op); + private: void GenerateJoinFilters(LogicalComparisonJoin &join); diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp index 539035710..0217b7307 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp @@ -26,36 +26,38 @@ struct DenomInfo { double denominator; }; -struct RelationsToTDom { +struct RelationsSetToStats { //! column binding sets that are equivalent in a join plan. //! if you have A.x = B.y and B.y = C.z, then one set is {A.x, B.y, C.z}. column_binding_set_t equivalent_relations; //! the estimated total domains of the equivalent relations determined using HLL - idx_t tdom_hll; + idx_t distinct_count_hll; //! 
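The join_elimination.hpp header above spells out when a left/semi/inner join can be dropped: only outer-table columns are projected, no filter references the inner table, and each outer row matches at most one inner row (for example, the join key is the inner table's primary key or covers a distinct group). A query of the shape below is the kind of candidate those comments describe; whether the optimizer actually fires on this exact query is an assumption here, not something the header guarantees:

```cpp
#include "duckdb.hpp"

using namespace duckdb;

int main() {
	DuckDB db(nullptr);
	Connection con(db);
	con.Query("CREATE TABLE customers(id INTEGER PRIMARY KEY, name VARCHAR)");
	con.Query("CREATE TABLE orders(order_id INTEGER, customer_id INTEGER, amount DECIMAL(10,2))");
	// Only orders columns are selected, nothing filters on customers, and the join
	// key is customers' primary key, so each order matches at most one customer:
	// the LEFT JOIN contributes no rows or columns and is a candidate for elimination.
	auto result = con.Query("SELECT o.order_id, o.amount "
	                        "FROM orders o LEFT JOIN customers c ON o.customer_id = c.id");
	result->Print();
	return 0;
}
```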
the estimated total domains of each relation without using HLL - idx_t tdom_no_hll; - bool has_tdom_hll; + idx_t distinct_count_no_hll; + bool has_distinct_count_hll; vector> filters; vector column_names; - explicit RelationsToTDom(const column_binding_set_t &column_binding_set) - : equivalent_relations(column_binding_set), tdom_hll(0), tdom_no_hll(NumericLimits::Maximum()), - has_tdom_hll(false) {}; + explicit RelationsSetToStats(const column_binding_set_t &column_binding_set) + : equivalent_relations(column_binding_set), distinct_count_hll(0), + distinct_count_no_hll(NumericLimits::Maximum()), has_distinct_count_hll(false) {}; }; +// class to wrap a join Filter along with some statistical information about the joined columns class FilterInfoWithTotalDomains { public: - FilterInfoWithTotalDomains(optional_ptr filter_info, RelationsToTDom &relation2tdom) - : filter_info(filter_info), tdom_hll(relation2tdom.tdom_hll), tdom_no_hll(relation2tdom.tdom_no_hll), - has_tdom_hll(relation2tdom.has_tdom_hll) { + FilterInfoWithTotalDomains(optional_ptr filter_info, RelationsSetToStats &relation_set_to_stats) + : filter_info(filter_info), distinct_count_hll(relation_set_to_stats.distinct_count_hll), + distinct_count_no_hll(relation_set_to_stats.distinct_count_no_hll), + has_distinct_count_hll(relation_set_to_stats.has_distinct_count_hll) { } optional_ptr filter_info; - //! the estimated total domains of the equivalent relations determined using HLL - idx_t tdom_hll; + //! the estimated distinct count the joined columns determined using HLL + idx_t distinct_count_hll; //! the estimated total domains of each relation without using HLL - idx_t tdom_no_hll; - bool has_tdom_hll; + idx_t distinct_count_no_hll; + bool has_distinct_count_hll; }; struct Subgraph2Denominator { @@ -91,7 +93,7 @@ class CardinalityEstimator { explicit CardinalityEstimator() {}; private: - vector relations_to_tdoms; + vector relation_set_stats; unordered_map relation_set_2_cardinality; JoinRelationSetManager set_manager; vector relation_stats; @@ -109,8 +111,8 @@ class CardinalityEstimator { T EstimateCardinalityWithSet(JoinRelationSet &new_set); //! used for debugging. - void AddRelationNamesToTdoms(vector &stats); - void PrintRelationToTdomInfo(); + void AddRelationNamesToRelationStats(vector &stats); + void PrintRelationStats(); private: double GetNumerator(JoinRelationSet &set); @@ -128,7 +130,7 @@ class CardinalityEstimator { JoinRelationSet &UpdateNumeratorRelations(Subgraph2Denominator left, Subgraph2Denominator right, FilterInfoWithTotalDomains &filter); - void AddRelationTdom(FilterInfo &filter_info); + void AddRelationStats(FilterInfo &filter_info); bool EmptyFilter(FilterInfo &filter_info); }; diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp index 3b8fda1c6..13f037a69 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp @@ -56,7 +56,11 @@ class RelationManager { //! Extract the set of relations referred to inside an expression bool ExtractBindings(Expression &expression, unordered_set &bindings); void AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats); - + //! 
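The cardinality_estimator.hpp rename above (tdom_* to distinct_count_*) makes explicit that the stored statistic is a distinct count per set of equivalent join columns, taken from HLL when available. For context, the textbook estimate such counts feed into divides the cross product by the larger distinct count of the joined columns; a standalone sketch of that formula, not DuckDB's actual estimator:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// |R JOIN S ON R.x = S.y| is commonly estimated as |R| * |S| / max(d(R.x), d(S.y)),
// where d() is the distinct count of the join column (HLL-based when available).
static double EstimateJoinCardinality(double left_card, double right_card,
                                      uint64_t left_distinct, uint64_t right_distinct) {
	auto denom = static_cast<double>(std::max<uint64_t>(std::max(left_distinct, right_distinct), 1));
	return left_card * right_card / denom;
}

int main() {
	// 1,000 rows joined with 5,000 rows on a key with at most 250 distinct values.
	std::cout << EstimateJoinCardinality(1000, 5000, 100, 250) << std::endl; // 20000
	return 0;
}
```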
Add an unnest relation which can come from a logical unnest or a logical get which has an unnest function + void AddRelationWithChildren(JoinOrderOptimizer &optimizer, LogicalOperator &op, LogicalOperator &input_op, + optional_ptr parent, RelationStats &child_stats, + optional_ptr limit_op, + vector> &datasource_filters); void AddAggregateOrWindowRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats, LogicalOperatorType op_type); vector> GetRelations(); diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp index c1f2c4586..4f70323dc 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp @@ -20,7 +20,19 @@ struct DistinctCount { }; struct ExpressionBinding { - bool found_expression = false; +public: + bool FoundExpression() const { + return expression; + } + bool FoundColumnRef() const { + if (!FoundExpression()) { + return false; + } + return expression->type == ExpressionType::BOUND_COLUMN_REF; + } + +public: + optional_ptr expression; ColumnBinding child_binding; bool expression_is_constant = false; }; diff --git a/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp b/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp new file mode 100644 index 000000000..ca0589fb0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/late_materialization_helper.hpp @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/late_materialization_helper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +namespace duckdb { + +struct LateMaterializationHelper { + static unique_ptr CreateLHSGet(const LogicalGet &rhs, Binder &binder); + static vector GetOrInsertRowIds(LogicalGet &get, const vector &row_id_column_ids, + const vector &row_id_columns); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp index c2e5b1fc4..404e2bb73 100644 --- a/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/remove_unused_columns.hpp @@ -28,6 +28,8 @@ class BaseColumnPruner : public LogicalOperatorVisitor { //! The map of column references column_binding_map_t column_references; + vector deliver_child; + protected: void VisitExpression(unique_ptr *expression) override; @@ -47,6 +49,8 @@ class BaseColumnPruner : public LogicalOperatorVisitor { bool HandleStructExtractRecursive(Expression &expr, optional_ptr &colref, vector &indexes); + + bool HandleStructPack(Expression &expr); }; //! 
The RemoveUnusedColumns optimizer traverses the logical operator tree and removes any columns that are not required @@ -68,5 +72,6 @@ class RemoveUnusedColumns : public BaseColumnPruner { private: template void ClearUnusedExpressions(vector &list, idx_t table_idx, bool replace = true); + void RemoveColumnsFromLogicalGet(LogicalGet &get); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/row_group_pruner.hpp b/src/duckdb/src/include/duckdb/optimizer/row_group_pruner.hpp new file mode 100644 index 000000000..9ce7b5ee2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/row_group_pruner.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/row_group_pruner.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/table_function.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/storage/table/scan_state.hpp" + +namespace duckdb { +class LogicalGet; +class LogicalOperator; + +class RowGroupPruner { +public: + explicit RowGroupPruner(ClientContext &context); + + //! Reorder and try to prune row groups in queries with LIMIT or simple aggregates + unique_ptr Optimize(unique_ptr op); + //! Whether we can perform the optimization on this operator + bool TryOptimize(LogicalOperator &op) const; + +private: + ClientContext &context; + +private: + void GetLimitAndOffset(const LogicalLimit &logical_limit, optional_idx &row_limit, optional_idx &row_offset) const; + optional_ptr FindLogicalOrder(const LogicalLimit &logical_limit) const; + optional_ptr FindLogicalGet(const LogicalOrder &logical_order, column_t &column_index) const; + // row_limit, row_offset, primary_order, logical_get, logical_limit + unique_ptr CreateRowGroupReordererOptions(optional_idx row_limit, optional_idx row_offset, + const BoundOrderByNode &primary_order, + const LogicalGet &logical_get, + column_t column_index, + LogicalLimit &logical_limit) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp new file mode 100644 index 000000000..775ba01fb --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/rule/constant_order_normalization.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +// Move constant expression parameters to the left in expression(i.e. x + 2 + y + 2 => 2 + 2 + x + y) +// for convenience of other rules(i.e. ConstantFoldingRule). 
+class ConstantOrderNormalizationRule : public Rule { +public: + explicit ConstantOrderNormalizationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp index 4b0099cd8..312fa7d93 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/list.hpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" #include "duckdb/optimizer/rule/case_simplification.hpp" +#include "duckdb/optimizer/rule/constant_order_normalization.hpp" #include "duckdb/optimizer/rule/comparison_simplification.hpp" #include "duckdb/optimizer/rule/conjunction_simplification.hpp" #include "duckdb/optimizer/rule/constant_folding.hpp" diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp index 1b757e78c..6330b4144 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp @@ -10,6 +10,7 @@ #include "duckdb/optimizer/rule.hpp" #include "duckdb/parser/expression_map.hpp" +#include "duckdb/parser/group_by_node.hpp" namespace duckdb { @@ -18,7 +19,8 @@ class OrderedAggregateOptimizer : public Rule { explicit OrderedAggregateOptimizer(ExpressionRewriter &rewriter); static unique_ptr Apply(ClientContext &context, BoundAggregateExpression &aggr, - vector> &groups, bool &changes_made); + vector> &groups, + optional_ptr> grouping_sets, bool &changes_made); unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) override; }; diff --git a/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp new file mode 100644 index 000000000..0fb5ba00c --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/topn_window_elimination.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/remove_unused_columns.hpp" + +namespace duckdb { + +enum class TopNPayloadType { SINGLE_COLUMN, STRUCT_PACK }; + +struct TopNWindowEliminationParameters { + //! Whether the sort is ASCENDING or DESCENDING + OrderType order_type; + //! The number of values in the LIMIT clause + int64_t limit; + //! How we fetch the payload columns + TopNPayloadType payload_type; + //! Whether to include row numbers + bool include_row_number; + //! 
Whether the val or arg column contains null values + bool can_be_null = false; +}; + +class TopNWindowElimination : public BaseColumnPruner { +public: + explicit TopNWindowElimination(ClientContext &context, Optimizer &optimizer, + optional_ptr>> stats_p); + + unique_ptr Optimize(unique_ptr op); + +private: + bool CanOptimize(LogicalOperator &op); + unique_ptr OptimizeInternal(unique_ptr op, ColumnBindingReplacer &replacer); + + unique_ptr CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const; + + vector> GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, map &group_idxs); + vector TraverseProjectionBindings(const std::vector &old_bindings, + reference &op); + unique_ptr CreateAggregateExpression(vector> aggregate_params, bool requires_arg, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const; + void AddStructExtractExprs(vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const; + static void UpdateTopmostBindings(idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, ColumnBindingReplacer &replacer); + TopNWindowEliminationParameters ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload); + + // Semi-join reduction methods + unique_ptr TryPrepareLateMaterialization(const LogicalWindow &window, + vector> &args); + unique_ptr ConstructLHS(LogicalGet &rhs, vector &projections) const; + static unique_ptr ConstructJoin(unique_ptr lhs, unique_ptr rhs, + idx_t rhs_rowid_idx, + const TopNWindowEliminationParameters ¶ms); + bool CanUseLateMaterialization(const LogicalWindow &window, vector> &args, + vector &projections, vector> &stack); + +private: + ClientContext &context; + Optimizer &optimizer; + optional_ptr>> stats; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp index 842d5ef2c..e1b14724a 100644 --- a/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp @@ -61,7 +61,8 @@ class UnnestRewriter { private: //! Find delim joins that contain an UNNEST - void FindCandidates(unique_ptr &op, vector>> &candidates); + void FindCandidates(unique_ptr &root, unique_ptr &op, + vector>> &candidates); //! Rewrite a delim join that contains an UNNEST bool RewriteCandidate(unique_ptr &candidate); //! 
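The topn_window_elimination.hpp header above targets the classic "top-N rows per group" pattern written as a ranking window function plus a filter on the rank, rewriting it into an aggregate (optionally followed by an unnest and a late-materialization join). The query shape it describes is sketched below; whether this exact form is rewritten is an assumption based on the header's comments:

```cpp
#include "duckdb.hpp"

using namespace duckdb;

int main() {
	DuckDB db(nullptr);
	Connection con(db);
	con.Query("CREATE TABLE sales(region VARCHAR, amount INTEGER)");
	con.Query("INSERT INTO sales VALUES ('east', 10), ('east', 30), ('west', 20), ('west', 5)");
	// Top-2 sales per region via row_number(): the pattern the optimizer's
	// parameters (order type, limit, optional row-number column) correspond to.
	auto result = con.Query(
	    "SELECT region, amount FROM ("
	    "  SELECT region, amount,"
	    "         row_number() OVER (PARTITION BY region ORDER BY amount DESC) AS rn"
	    "  FROM sales) WHERE rn <= 2");
	result->Print();
	return 0;
}
```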
Update the bindings of the RHS sequence of LOGICAL_PROJECTION(s) diff --git a/src/duckdb/src/include/duckdb/parallel/async_result.hpp b/src/duckdb/src/include/duckdb/parallel/async_result.hpp new file mode 100644 index 000000000..97ede1cbc --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/async_result.hpp @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/async_result.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/operator_result_type.hpp" + +namespace duckdb { + +class InterruptState; +class TaskExecutor; +class Executor; + +enum class AsyncResultsExecutionMode : uint8_t { + SYNCHRONOUS, // BLOCKED should not bubble up, and they should be executed synchronously + TASK_EXECUTOR // BLOCKED is allowed +}; + +class AsyncTask { +public: + virtual ~AsyncTask() {}; + virtual void Execute() = 0; +}; + +class AsyncResult { + explicit AsyncResult(AsyncResultType t); + +public: + AsyncResult() = default; + AsyncResult(AsyncResult &&) = default; + AsyncResult(SourceResultType t); // NOLINT + explicit AsyncResult(vector> &&task); + AsyncResult &operator=(SourceResultType t); + AsyncResult &operator=(AsyncResultType t); + AsyncResult &operator=(AsyncResult &&) noexcept; + // Schedule held async_tasks into the Executor, eventually unblocking InterruptState + // needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + // into INVALID + void ScheduleTasks(InterruptState &interrupt_state, Executor &executor); + // Execute tasks synchronously at callsite + // needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + // into HAVE_MORE_OUTPUT + void ExecuteTasksSynchronously(); + + static AsyncResultType GetAsyncResultType(SourceResultType s); + + // Check whether there are tasks associated + bool HasTasks() const; + AsyncResultType GetResultType() const; + // Extract associated tasks, moving them away, will empty async_tasks and trasnform to INVALID + vector> &&ExtractAsyncTasks(); + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE + static vector> GenerateTestTasks(); +#endif + + static AsyncResultsExecutionMode + ConvertToAsyncResultExecutionMode(const PhysicalTableScanExecutionStrategy &execution_mode); + +private: + AsyncResultType result_type {AsyncResultType::INVALID}; + vector> async_tasks {}; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp index ef0bf8139..b2db8497e 100644 --- a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp +++ b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp @@ -83,6 +83,11 @@ class StateWithBlockableTasks { return false; } + bool CanBlock(const unique_lock &guard) const { + VerifyLock(guard); + return can_block; + } + //! 
Unblock all tasks (must hold the lock) bool UnblockTasks(const unique_lock &guard) { VerifyLock(guard); diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp index 9781e6fb8..73cd38895 100644 --- a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp @@ -112,8 +112,8 @@ class PipelineExecutor { //! Partition info that is used by this executor OperatorPartitionInfo required_partition_info; - //! Source has indicated it is exhausted - bool exhausted_source = false; + //! Source or intermediate operator indicated that there is no more output possible + bool exhausted_pipeline = false; //! Flushing of intermediate operators has started bool started_flushing = false; //! Flushing of caching operators is done @@ -152,7 +152,7 @@ class PipelineExecutor { OperatorResultType Execute(DataChunk &input, DataChunk &result, idx_t initial_index = 0); //! Notifies the sink that a new batch has started - SinkNextBatchType NextBatch(DataChunk &source_chunk); + SinkNextBatchType NextBatch(DataChunk &source_chunk, const bool have_more_output); //! Tries to flush all state from intermediate operators. Will return true if all state is flushed, false in the //! case of a blocked sink. diff --git a/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp b/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp new file mode 100644 index 000000000..f53fc1d4f --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/sleep_async_task.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/sleep_async_task.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parallel/async_result.hpp" + +#include +#include + +namespace duckdb { + +class SleepAsyncTask : public AsyncTask { +public: + explicit SleepAsyncTask(idx_t sleep_for) : sleep_for(sleep_for) { + } + void Execute() override { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_for)); + } + const idx_t sleep_for; +}; + +class ThrowAsyncTask : public AsyncTask { +public: + explicit ThrowAsyncTask(idx_t sleep_for) : sleep_for(sleep_for) { + } + void Execute() override { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_for)); + throw NotImplementedException("ThrowAsyncTask: Test error handling when throwing mid-task"); + } + const idx_t sleep_for; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/base_expression.hpp b/src/duckdb/src/include/duckdb/parser/base_expression.hpp index 7e339e2b9..9d55b6b40 100644 --- a/src/duckdb/src/include/duckdb/parser/base_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/base_expression.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" #include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/optional_idx.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/column_definition.hpp b/src/duckdb/src/include/duckdb/parser/column_definition.hpp index fa5cc8d38..c85a2d689 100644 --- a/src/duckdb/src/include/duckdb/parser/column_definition.hpp +++ b/src/duckdb/src/include/duckdb/parser/column_definition.hpp @@ -13,7 +13,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/common/enums/compression_type.hpp" #include "duckdb/catalog/catalog_entry/table_column_type.hpp" 
-#include "duckdb/common/case_insensitive_map.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp index 552753692..7a784eeb1 100644 --- a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp @@ -16,16 +16,20 @@ namespace duckdb { class SelectStatement; struct CommonTableExpressionInfo { + ~CommonTableExpressionInfo(); + vector aliases; vector> key_targets; unique_ptr query; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; +public: void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); unique_ptr Copy(); - ~CommonTableExpressionInfo(); +private: + CTEMaterialize GetMaterializedForSerialization(Serializer &serializer) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraint.hpp index 1c497b1ed..22c89a123 100644 --- a/src/duckdb/src/include/duckdb/parser/constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraint.hpp @@ -8,7 +8,7 @@ #pragma once -#include "duckdb/common/common.hpp" +#include "duckdb/common/constants.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/exception.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp index cba2a61fb..d5ea83fd5 100644 --- a/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraints/check_constraint.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/string_util.hpp" #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/parsed_expression.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp index ded23b0d3..dd73a87a0 100644 --- a/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp @@ -8,8 +8,6 @@ #pragma once -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/enums/index_constraint_type.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/parser/column_list.hpp" #include "duckdb/parser/constraint.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp index 2e12da048..41ea7d713 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/vector.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { //! 
Represents a function call diff --git a/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp index 3551e0982..9ef98bc4b 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp @@ -11,8 +11,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/parser/qualified_name.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" namespace duckdb { //! Represents a built-in operator expression diff --git a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp index cc5e9c61f..115b0d268 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp @@ -13,6 +13,11 @@ namespace duckdb { +struct WindowFunctionDefinition { + const char *name; + ExpressionType expression_type; +}; + enum class WindowBoundary : uint8_t { INVALID = 0, UNBOUNDED_PRECEDING = 1, @@ -57,9 +62,9 @@ class WindowExpression : public ParsedExpression { //! Expression representing a filter, only used for aggregates unique_ptr filter_expr; //! True to ignore NULL values - bool ignore_nulls; + bool ignore_nulls = false; //! Whether or not the aggregate function is distinct, only used for aggregates - bool distinct; + bool distinct = false; //! The window boundaries WindowBoundary start = WindowBoundary::INVALID; WindowBoundary end = WindowBoundary::INVALID; @@ -92,6 +97,7 @@ class WindowExpression : public ParsedExpression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + static const WindowFunctionDefinition *WindowFunctions(); static ExpressionType WindowToExpressionType(string &fun_name); public: diff --git a/src/duckdb/src/include/duckdb/parser/expression_map.hpp b/src/duckdb/src/include/duckdb/parser/expression_map.hpp index 75ecd78e8..406809689 100644 --- a/src/duckdb/src/include/duckdb/parser/expression_map.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression_map.hpp @@ -10,7 +10,6 @@ #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" -#include "duckdb/parser/base_expression.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/planner/expression.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp index 7dcf83074..ddf31b265 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/common/enums/catalog_type.hpp" -#include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp index f3a209462..812d75d82 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_scalar_function_info.hpp @@ -8,8 +8,6 @@ #pragma once -#include 
"duckdb/function/function_set.hpp" -#include "duckdb/function/scalar_function.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp index fc7944c3f..468fbb405 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_function_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/function/function_set.hpp" -#include "duckdb/function/table_function.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp index b206e9143..1408d6e28 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp @@ -11,7 +11,6 @@ #include "duckdb/parser/parsed_data/alter_info.hpp" #include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/constraint.hpp" -#include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/result_modifier.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/bound_pragma_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/bound_pragma_info.hpp index 763e98a2b..90f7c27b3 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/bound_pragma_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/bound_pragma_info.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_data/pragma_info.hpp" #include "duckdb/function/pragma_function.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp index 60aa36ef0..cbaf8e7ba 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp @@ -11,7 +11,6 @@ #include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" -#include "duckdb/parser/qualified_name.hpp" namespace duckdb { class CatalogEntryRetriever; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp index fb6522cc7..10f7e271a 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/parsed_data/create_info.hpp" -#include "duckdb/function/function.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp index 2e8d9a73d..272ac378f 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/parsed_data/create_function_info.hpp" -#include "duckdb/function/pragma_function.hpp" #include "duckdb/function/function_set.hpp" namespace duckdb { diff --git 
a/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp index 12aee6052..bec1569c8 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_scalar_function_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/parsed_data/create_function_info.hpp" -#include "duckdb/function/scalar_function.hpp" #include "duckdb/function/function_set.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp index 76c4353b6..fe345d6e8 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp @@ -9,11 +9,8 @@ #pragma once #include "duckdb/main/secret/secret.hpp" -#include "duckdb/common/enums/catalog_type.hpp" -#include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/parsed_data/create_info.hpp" -#include "duckdb/common/enums/on_entry_not_found.hpp" #include "duckdb/common/named_parameter_map.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp index ebb1fb452..1cd1ffb0b 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/parsed_data/create_info.hpp" -#include "duckdb/common/limits.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp index 20a8bd788..84423dc59 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp @@ -9,11 +9,8 @@ #pragma once #include "duckdb/parser/parsed_data/create_info.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/statement/select_statement.hpp" -#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" #include "duckdb/parser/column_list.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp index 6b50064a1..2c4eb983f 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp @@ -9,8 +9,6 @@ #pragma once #include "duckdb/parser/parsed_data/create_info.hpp" -#include "duckdb/parser/column_definition.hpp" -#include "duckdb/parser/constraint.hpp" #include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/extra_drop_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/extra_drop_info.hpp index 2812469de..b2d952042 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/extra_drop_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/extra_drop_info.hpp @@ -11,7 +11,6 @@ #include "duckdb/main/secret/secret.hpp" #include 
"duckdb/common/enums/catalog_type.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" -#include "duckdb/common/enums/on_entry_not_found.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp index bbf01242e..c76b2fcbf 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/types/value.hpp" -#include "duckdb/common/named_parameter_map.hpp" #include "duckdb/parser/parsed_expression.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp index dadbcfe92..b9cb0841c 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp @@ -8,10 +8,8 @@ #pragma once -#include "duckdb/common/common.hpp" -#include "duckdb/parser/parsed_expression.hpp" -#include "duckdb/common/vector.hpp" #include "duckdb/common/types/value.hpp" +#include "duckdb/common/optional_idx.hpp" namespace duckdb { @@ -23,6 +21,9 @@ enum class SampleMethod : uint8_t { SYSTEM_SAMPLE = 0, BERNOULLI_SAMPLE = 1, RES string SampleMethodToString(SampleMethod method); class SampleOptions { +public: + // 1 billion rows should be enough. + static constexpr idx_t MAX_SAMPLE_ROWS = 1000000000; public: explicit SampleOptions(int64_t seed_ = -1); diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp index 5540d38a2..9b5ce8812 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp @@ -10,10 +10,6 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/optional_ptr.hpp" -#include "duckdb/catalog/dependency_list.hpp" namespace duckdb { class Serializer; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp index 46acfe9e6..3421bca88 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp @@ -10,9 +10,7 @@ #include "duckdb/parser/base_expression.hpp" #include "duckdb/common/vector.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/parser/qualified_name.hpp" -#include "duckdb/parser/expression_util.hpp" namespace duckdb { class Deserializer; diff --git a/src/duckdb/src/include/duckdb/parser/parser.hpp b/src/duckdb/src/include/duckdb/parser/parser.hpp index ce373fe9c..37b151236 100644 --- a/src/duckdb/src/include/duckdb/parser/parser.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser.hpp @@ -14,6 +14,8 @@ #include "duckdb/parser/column_list.hpp" #include "duckdb/parser/simplified_token.hpp" #include "duckdb/parser/parser_options.hpp" +#include "duckdb/common/exception/parser_exception.hpp" +#include "duckdb/parser/parser_extension.hpp" namespace duckdb_libpgquery { struct PGNode; @@ -73,6 +75,9 @@ class Parser { static bool StripUnicodeSpaces(const string &query_str, string &new_query); + unique_ptr GetStatement(const string 
&query); + void ThrowParserOverrideError(ParserOverrideResult &result); + private: ParserOptions options; }; diff --git a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp index 61c071307..cf264adc7 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp @@ -67,7 +67,7 @@ struct ParserExtensionPlanResult { // NOLINT: work-around bug in clang-tidy //! Parameters to the function vector parameters; //! The set of databases that will be modified by this statement (empty for a read-only statement) - unordered_map modified_databases; + unordered_map modified_databases; //! Whether or not the statement requires a valid transaction to be executed bool requires_valid_transaction = true; //! What type of result set the statement returns @@ -86,12 +86,12 @@ struct ParserOverrideResult { explicit ParserOverrideResult(vector> statements_p) : type(ParserExtensionResultType::PARSE_SUCCESSFUL), statements(std::move(statements_p)) {}; - explicit ParserOverrideResult(const string &error_p) + explicit ParserOverrideResult(std::exception &error_p) : type(ParserExtensionResultType::DISPLAY_EXTENSION_ERROR), error(error_p) {}; ParserExtensionResultType type; vector> statements; - string error; + ErrorData error; }; typedef ParserOverrideResult (*parser_override_function_t)(ParserExtensionInfo *info, const string &query); @@ -103,14 +103,14 @@ class ParserExtension { public: //! The parse function of the parser extension. //! Takes a query string as input and returns ParserExtensionParseData (on success) or an error - parse_function_t parse_function; + parse_function_t parse_function = nullptr; //! The plan function of the parser extension //! Takes as input the result of the parse_function, and outputs various properties of the resulting plan - plan_function_t plan_function; + plan_function_t plan_function = nullptr; //! Override the current parser with a new parser and return a vector of SQL statements - parser_override_function_t parser_override; + parser_override_function_t parser_override = nullptr; //! 
Additional parser info passed to the parse function shared_ptr parser_info; diff --git a/src/duckdb/src/include/duckdb/parser/parser_options.hpp b/src/duckdb/src/include/duckdb/parser/parser_options.hpp index d388fb116..d9a42632a 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_options.hpp @@ -18,6 +18,7 @@ struct ParserOptions { bool integer_division = false; idx_t max_expression_depth = 1000; const vector *extensions = nullptr; + string parser_override_setting = "default"; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp index 25f88b0df..38f3a43cb 100644 --- a/src/duckdb/src/include/duckdb/parser/qualified_name.hpp +++ b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp @@ -9,10 +9,8 @@ #pragma once #include "duckdb/common/string.hpp" -#include "duckdb/common/exception/parser_exception.hpp" -#include "duckdb/parser/keyword_helper.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/planner/binding_alias.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp index c17cde85d..5dd28c78c 100644 --- a/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp +++ b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/qualified_name.hpp" -#include "duckdb/common/types/hash.hpp" #include "duckdb/common/unordered_set.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/query_error_context.hpp b/src/duckdb/src/include/duckdb/parser/query_error_context.hpp index f0de3fa11..ccfe709ca 100644 --- a/src/duckdb/src/include/duckdb/parser/query_error_context.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_error_context.hpp @@ -8,9 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/common/exception_format_value.hpp" #include "duckdb/common/optional_idx.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp index ec03da095..4981e8afc 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/result_modifier.hpp" #include "duckdb/parser/common_table_expression_info.hpp" @@ -25,7 +24,8 @@ enum class QueryNodeType : uint8_t { SET_OPERATION_NODE = 2, BOUND_SUBQUERY_NODE = 3, RECURSIVE_CTE_NODE = 4, - CTE_NODE = 5 + CTE_NODE = 5, + STATEMENT_NODE = 6 }; struct CommonTableExpressionInfo; @@ -59,8 +59,6 @@ class QueryNode { //! CTEs (used by SelectNode and SetOperationNode) CommonTableExpressionMap cte_map; - virtual const vector> &GetSelectList() const = 0; - public: //! 
Convert the query node to a string virtual string ToString() const = 0; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp index bc997a6c7..3bca2cd06 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -10,10 +10,10 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/sql_statement.hpp" namespace duckdb { +//! DEPRECATED - CTENode is only preserved for backwards compatibility when serializing older databases class CTENode : public QueryNode { public: static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; @@ -23,30 +23,18 @@ class CTENode : public QueryNode { } string ctename; - //! The query of the CTE unique_ptr query; - //! Child unique_ptr child; - //! Aliases of the CTE node vector aliases; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - const vector> &GetSelectList() const override { - return query->GetSelectList(); - } - public: - //! Convert the query node to a string string ToString() const override; bool Equals(const QueryNode *other) const override; - //! Create a copy of this SelectNode unique_ptr Copy() const override; - //! Serializes a QueryNode to a stand-alone binary blob - //! Deserializes a blob back into a QueryNode - void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &source); }; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp index 94bfd3438..3a2894cc4 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp @@ -2,3 +2,4 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/statement_node.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp index 6d73fda4a..e4c8c619c 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/sql_statement.hpp" namespace duckdb { @@ -33,10 +32,6 @@ class RecursiveCTENode : public QueryNode { //! targets for key variants vector> key_targets; - const vector> &GetSelectList() const override { - return left->GetSelectList(); - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp index 62aa9c0b2..38b6ae7b1 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/sql_statement.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" @@ -43,10 +42,6 @@ class SelectNode : public QueryNode { //! 
The SAMPLE clause unique_ptr sample; - const vector> &GetSelectList() const override { - return select_list; - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp index 960f6c2d6..2dfb63cdb 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp @@ -11,7 +11,6 @@ #include "duckdb/common/enums/set_operation_type.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/sql_statement.hpp" namespace duckdb { @@ -29,8 +28,6 @@ class SetOperationNode : public QueryNode { //! The children of the set operation vector> children; - const vector> &GetSelectList() const override; - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp new file mode 100644 index 000000000..26db46a58 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/statement_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class StatementNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::STATEMENT_NODE; + +public: + explicit StatementNode(SQLStatement &stmt_p); + + SQLStatement &stmt; + +public: + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! 
Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/result_modifier.hpp b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp index f1d796527..8e777c082 100644 --- a/src/duckdb/src/include/duckdb/parser/result_modifier.hpp +++ b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/enums/order_type.hpp" #include "duckdb/parser/parsed_expression.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/simplified_token.hpp b/src/duckdb/src/include/duckdb/parser/simplified_token.hpp index 7a50824d0..ff073a3cf 100644 --- a/src/duckdb/src/include/duckdb/parser/simplified_token.hpp +++ b/src/duckdb/src/include/duckdb/parser/simplified_token.hpp @@ -21,7 +21,9 @@ enum class SimplifiedTokenType : uint8_t { SIMPLIFIED_TOKEN_OPERATOR, SIMPLIFIED_TOKEN_KEYWORD, SIMPLIFIED_TOKEN_COMMENT, - SIMPLIFIED_TOKEN_ERROR + SIMPLIFIED_TOKEN_ERROR, + SIMPLIFIED_TOKEN_ERROR_EMPHASIS, + SIMPLIFIED_TOKEN_ERROR_SUGGESTION }; struct SimplifiedToken { diff --git a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp index b9d03a07e..2e8278d64 100644 --- a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" #include "duckdb/common/enums/statement_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/printer.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp index afd5d256a..97be56e05 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/call_statement.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" -#include "duckdb/common/vector.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp index d206bc5f6..a70ca1193 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp @@ -8,8 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_data/copy_info.hpp" -#include "duckdb/parser/query_node.hpp" #include "duckdb/parser/sql_statement.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp index 2a15165dc..297ab031a 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" -#include "duckdb/common/vector.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp index 7a30b6b64..0ab5acec8 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp +++ 
b/src/duckdb/src/include/duckdb/parser/statement/explain_statement.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" #include "duckdb/common/enums/explain_format.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp index 97430a930..60f2f88c6 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/export_statement.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp index 249c468b3..3415eb31a 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/pragma_statement.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/sql_statement.hpp" #include "duckdb/parser/parsed_data/pragma_info.hpp" -#include "duckdb/parser/parsed_expression.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp index c2be75989..c47cb60f0 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp index 5d7295ca3..a7352870c 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp @@ -14,6 +14,7 @@ namespace duckdb { +class QueryNode; class Serializer; class Deserializer; diff --git a/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp index 1818820db..f7051f893 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp @@ -11,7 +11,6 @@ #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/common/enums/set_type.hpp" #include "duckdb/parser/sql_statement.hpp" -#include "duckdb/common/types/value.hpp" #include "duckdb/parser/parsed_expression.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/update_extensions_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/update_extensions_statement.hpp index 2c361b336..479ec5156 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/update_extensions_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/update_extensions_statement.hpp @@ -8,11 +8,7 @@ #pragma once -#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" -#include "duckdb/parser/tableref.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/parser/query_node.hpp" #include "duckdb/parser/parsed_data/update_extensions_info.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp 
b/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp index 07b04d1cd..fb77b514b 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/vacuum_statement.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/sql_statement.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp index e5b2d6d2f..b2339d9bc 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/vector.hpp" #include "duckdb/main/table_description.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/tableref/at_clause.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp index be997eb5d..c8586a559 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/binder.hpp" namespace duckdb { @@ -20,10 +19,10 @@ class BoundRefWrapper : public TableRef { static constexpr const TableReferenceType TYPE = TableReferenceType::BOUND_TABLE_REF; public: - BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p); + BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p); //! The bound reference object - unique_ptr bound_ref; + BoundStatement bound_ref; //! 
The binder that was used to bind this table ref shared_ptr binder; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp index ea6889362..1cf31d1eb 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp @@ -13,7 +13,6 @@ namespace duckdb { class DelimGetRef : public TableRef { - public: explicit DelimGetRef(const vector &types_p) : TableRef(TableReferenceType::DELIM_GET), types(types_p) { for (idx_t i = 0; i < types.size(); i++) { diff --git a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp index 6aa6e8015..f31992aef 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp @@ -10,7 +10,6 @@ #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/joinref_type.hpp" -#include "duckdb/common/unordered_set.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/common/vector.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp index 77a37d444..163bf66ad 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp @@ -10,8 +10,6 @@ #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/parsed_expression.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/vector.hpp" #include "duckdb/parser/query_node.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp index c76377991..e20a8eb49 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/tableref.hpp" -#include "duckdb/common/vector.hpp" #include "duckdb/parser/statement/select_statement.hpp" #include "duckdb/common/enums/ordinality_request_type.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/tokens.hpp b/src/duckdb/src/include/duckdb/parser/tokens.hpp index 6eeb8c5e2..d5646739c 100644 --- a/src/duckdb/src/include/duckdb/parser/tokens.hpp +++ b/src/duckdb/src/include/duckdb/parser/tokens.hpp @@ -53,6 +53,7 @@ class SelectNode; class SetOperationNode; class RecursiveCTENode; class CTENode; +class StatementNode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/duckdb/src/include/duckdb/parser/transformer.hpp b/src/duckdb/src/include/duckdb/parser/transformer.hpp index 59e4f0419..2afa96722 100644 --- a/src/duckdb/src/include/duckdb/parser/transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/transformer.hpp @@ -10,7 +10,6 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/constants.hpp" -#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/stack_checker.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/unordered_map.hpp" @@ -19,7 +18,6 @@ #include "duckdb/parser/parsed_data/create_secret_info.hpp" #include "duckdb/parser/qualified_name.hpp" #include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/query_node/cte_node.hpp" #include 
"duckdb/parser/tokens.hpp" #include "nodes/parsenodes.hpp" #include "nodes/primnodes.hpp" @@ -80,7 +78,7 @@ class Transformer { //! The set of pivot entries to create vector> pivot_entries; //! Sets of stored CTEs, if any - vector stored_cte_map; + vector> stored_cte_map; //! Whether or not we are currently binding a window definition bool in_window_definition = false; @@ -304,7 +302,6 @@ class Transformer { string TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias); vector TransformStringList(duckdb_libpgquery::PGList *list); void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map); - static unique_ptr TransformMaterializedCTE(unique_ptr root); unique_ptr TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node, CommonTableExpressionInfo &info); diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index d9c20dd1d..db5b52c78 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -23,7 +23,7 @@ namespace duckdb { class Binder; class LogicalGet; -class BoundQueryNode; +struct BoundStatement; class StarExpression; @@ -43,9 +43,6 @@ class BindContext { public: explicit BindContext(Binder &binder); - //! Keep track of recursive CTE references - case_insensitive_map_t> cte_references; - public: //! Given a column name, find the matching table it belongs to. Throws an //! exception if no table has a column of the given name. @@ -57,7 +54,7 @@ class BindContext { //! matching ones vector GetSimilarBindings(const string &column_name); - optional_ptr GetCTEBinding(const string &ctename); + optional_ptr GetCTEBinding(const BindingAlias &ctename); //! Binds a column expression to the base table. Returns the bound expression //! or throws an exception if the column could not be bound. BindResult BindColumn(ColumnRefExpression &colref, idx_t depth); @@ -105,11 +102,11 @@ class BindContext { const vector &types, vector &bound_column_ids, optional_ptr entry, virtual_column_map_t virtual_columns); //! Adds a table view with a given alias to the BindContext. - void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, ViewCatalogEntry &view); + void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery); //! Adds a binding to a catalog entry with a given alias to the BindContext. void AddEntryBinding(idx_t index, const string &alias, const vector &names, const vector &types, StandardEntry &entry); @@ -119,10 +116,9 @@ class BindContext { //! Adds a base table with the given alias to the CTE BindContext. //! We need this to correctly bind recursive CTEs with multiple references. 
- void AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types, - bool using_key = false); - - void RemoveCTEBinding(const string &alias); + void AddCTEBinding(idx_t index, BindingAlias alias, const vector &names, const vector &types, + CTEType cte_type = CTEType::CAN_BE_REFERENCED); + void AddCTEBinding(unique_ptr binding); //! Add an implicit join condition (e.g. USING (x)) void AddUsingBinding(const string &column_name, UsingColumnSet &set); @@ -146,13 +142,6 @@ class BindContext { string GetActualColumnName(const BindingAlias &binding_alias, const string &column_name); string GetActualColumnName(Binding &binding, const string &column_name); - case_insensitive_map_t> GetCTEBindings() { - return cte_bindings; - } - void SetCTEBindings(case_insensitive_map_t> bindings) { - cte_bindings = std::move(bindings); - } - //! Alias a set of column names for the specified table, using the original names if there are not enough aliases //! specified. static vector AliasColumnNames(const string &table_name, const vector &names, @@ -184,10 +173,7 @@ class BindContext { vector> bindings_list; //! The set of columns used in USING join conditions case_insensitive_map_t> using_columns; - //! Using column sets - vector> using_column_sets; - //! The set of CTE bindings - case_insensitive_map_t> cte_bindings; + vector> cte_bindings; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 5a664f2dc..523ba484a 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -27,7 +27,6 @@ #include "duckdb/planner/joinside.hpp" #include "duckdb/planner/bound_constraint.hpp" #include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" #include "duckdb/common/enums/copy_option_mode.hpp" //! fwd declare @@ -69,7 +68,9 @@ struct PivotColumnEntry; struct UnpivotEntry; struct CopyInfo; struct CopyOption; - +struct BoundSetOpChild; +struct BoundCTEData; +enum class CopyToType : uint8_t; template class IndexVector; @@ -100,6 +101,89 @@ struct CorrelatedColumnInfo { } }; +struct CorrelatedColumns { +private: + using container_type = vector; + +public: + CorrelatedColumns() : delim_index(1ULL << 63) { + } + + void AddColumn(container_type::value_type info) { + // Add to beginning + correlated_columns.insert(correlated_columns.begin(), std::move(info)); + delim_index++; + } + + void SetDelimIndexToZero() { + delim_index = 0; + } + + idx_t GetDelimIndex() const { + return delim_index; + } + + const container_type::value_type &operator[](const idx_t &index) const { + return correlated_columns.at(index); + } + + idx_t size() const { // NOLINT: match stl case + return correlated_columns.size(); + } + + bool empty() const { // NOLINT: match stl case + return correlated_columns.empty(); + } + + void clear() { // NOLINT: match stl case + correlated_columns.clear(); + } + + container_type::iterator begin() { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::iterator end() { // NOLINT: match stl case + return correlated_columns.end(); + } + + container_type::const_iterator begin() const { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::const_iterator end() const { // NOLINT: match stl case + return correlated_columns.end(); + } + +private: + container_type correlated_columns; + idx_t delim_index; +}; + +//! 
GlobalBinderState is state shared over the ENTIRE query, including subqueries, views, etc +struct GlobalBinderState { + //! The count of bound_tables + idx_t bound_tables = 0; + //! Statement properties + StatementProperties prop; + //! Binding mode + BindingMode mode = BindingMode::STANDARD_BINDING; + //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. + unordered_set table_names; + //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS + case_insensitive_map_t> replacement_scans; + //! Using column sets + vector> using_column_sets; + //! The set of parameter expressions bound by this binder + optional_ptr parameters; +}; + +// QueryBinderState is state shared WITHIN a query, a new query-binder state is created when binding inside e.g. a view +struct QueryBinderState { + //! The vector of active binders + vector> active_binders; +}; + //! Bind the parsed query tree to the actual columns present in the catalog. /*! The binder is responsible for binding tables and columns to actual physical @@ -116,15 +200,11 @@ class Binder : public enable_shared_from_this { //! The client context ClientContext &context; - //! A mapping of names to common table expressions - case_insensitive_set_t CTE_bindings; // NOLINT //! The bind context BindContext bind_context; //! The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a //! vector) - vector correlated_columns; - //! The set of parameter expressions bound by this binder - optional_ptr parameters; + CorrelatedColumns correlated_columns; //! The alias for the currently processing subquery, if it exists string alias; //! Macro parameter bindings (if any) @@ -171,8 +251,7 @@ class Binder : public enable_shared_from_this { QueryErrorContext &error_context, string &func_name); unique_ptr BindPragma(PragmaInfo &info, QueryErrorContext error_context); - unique_ptr Bind(TableRef &ref); - unique_ptr CreatePlan(BoundTableRef &ref); + BoundStatement Bind(TableRef &ref); //! Generates an unused index for a table idx_t GenerateTableIndex(); @@ -180,12 +259,8 @@ class Binder : public enable_shared_from_this { optional_ptr GetCatalogEntry(const string &catalog, const string &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found); - //! Add a common table expression to the binder - void AddCTE(const string &name); //! Find all candidate common table expression by name; returns empty vector if none exists - vector> FindCTE(const string &name, bool skip = false); - - bool CTEExists(const string &name); + optional_ptr GetCTEBinding(const BindingAlias &name); //! Add the view to the set of currently bound views - used for detecting recursive view definitions void AddBoundView(ViewCatalogEntry &view); @@ -198,7 +273,7 @@ class Binder : public enable_shared_from_this { vector> &GetActiveBinders(); - void MergeCorrelatedColumns(vector &other); + void MergeCorrelatedColumns(CorrelatedColumns &other); //! 
Add a correlated column to this binder (if it does not exist) void AddCorrelatedColumn(const CorrelatedColumnInfo &info); @@ -228,12 +303,11 @@ class Binder : public enable_shared_from_this { void AddReplacementScan(const string &table_name, unique_ptr replacement); const unordered_set &GetTableNames(); case_insensitive_map_t> &GetReplacementScans(); - optional_ptr GetRootStatement() { - return root_statement; - } CatalogEntryRetriever &EntryRetriever() { return entry_retriever; } + optional_ptr GetParameters(); + void SetParameters(BoundParameterMap ¶meters); //! Returns a ColumnRefExpression after it was resolved (i.e. past the STAR expression/USING clauses) static optional_ptr GetResolvedColumnExpression(ParsedExpression &root_expr); @@ -250,42 +324,28 @@ class Binder : public enable_shared_from_this { private: //! The parent binder (if any) shared_ptr parent; - //! The vector of active binders - vector> active_binders; - //! The count of bound_tables - idx_t bound_tables; + //! What kind of node we are binding using this binder + BinderType binder_type = BinderType::REGULAR_BINDER; + //! Global binder state + shared_ptr global_binder_state; + //! Query binder state + shared_ptr query_binder_state; //! Whether or not the binder has any unplanned dependent joins that still need to be planned/flattened bool has_unplanned_dependent_joins = false; //! Whether or not outside dependent joins have been planned and flattened bool is_outside_flattened = true; - //! What kind of node we are binding using this binder - BinderType binder_type = BinderType::REGULAR_BINDER; //! Whether or not the binder can contain NULLs as the root of expressions bool can_contain_nulls = false; - //! The root statement of the query that is currently being parsed - optional_ptr root_statement; - //! Binding mode - BindingMode mode = BindingMode::STANDARD_BINDING; - //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. - unordered_set table_names; - //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS - case_insensitive_map_t> replacement_scans; //! The set of bound views reference_set_t bound_views; //! Used to retrieve CatalogEntry's CatalogEntryRetriever entry_retriever; //! Unnamed subquery index idx_t unnamed_subquery_index = 1; - //! Statement properties - StatementProperties prop; - //! Root binder - Binder &root_binder; //! Binder depth idx_t depth; private: - //! Get the root binder (binder with no parent) - Binder &GetRootBinder(); //! Determine the depth of the binder idx_t GetBinderDepth() const; //! Increase the depth of the binder @@ -303,7 +363,7 @@ class Binder : public enable_shared_from_this { void MoveCorrelatedExpressions(Binder &other); //! 
Tries to bind the table name with replacement scans - unique_ptr BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); + BoundStatement BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); template BoundStatement BindWithCTE(T &statement); @@ -344,41 +404,39 @@ class Binder : public enable_shared_from_this { unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - unique_ptr BindMaterializedCTE(CommonTableExpressionMap &cte_map); - unique_ptr BindCTE(CTENode &statement); + BoundStatement BindCTE(const string &ctename, CommonTableExpressionInfo &info); - unique_ptr BindNode(SelectNode &node); - unique_ptr BindNode(SetOperationNode &node); - unique_ptr BindNode(RecursiveCTENode &node); - unique_ptr BindNode(CTENode &node); - unique_ptr BindNode(QueryNode &node); + BoundStatement BindNode(SelectNode &node); + BoundStatement BindNode(SetOperationNode &node); + BoundStatement BindNode(RecursiveCTENode &node); + BoundStatement BindNode(QueryNode &node); + BoundStatement BindNode(StatementNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); - unique_ptr CreatePlan(BoundRecursiveCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node, unique_ptr base); unique_ptr CreatePlan(BoundSelectNode &statement); unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); - unique_ptr BindJoin(Binder &parent, TableRef &ref); - unique_ptr Bind(BaseTableRef &ref); - unique_ptr Bind(BoundRefWrapper &ref); - unique_ptr Bind(JoinRef &ref); - unique_ptr Bind(SubqueryRef &ref); - unique_ptr Bind(TableFunctionRef &ref); - unique_ptr Bind(EmptyTableRef &ref); - unique_ptr Bind(DelimGetRef &ref); - unique_ptr Bind(ExpressionListRef &ref); - unique_ptr Bind(ColumnDataRef &ref); - unique_ptr Bind(PivotRef &expr); - unique_ptr Bind(ShowRef &ref); + void BuildUnionByNameInfo(BoundSetOperationNode &result); + + BoundStatement BindJoin(Binder &parent, TableRef &ref); + BoundStatement Bind(BaseTableRef &ref); + BoundStatement Bind(BoundRefWrapper &ref); + BoundStatement Bind(JoinRef &ref); + BoundStatement Bind(SubqueryRef &ref); + BoundStatement Bind(TableFunctionRef &ref); + BoundStatement Bind(EmptyTableRef &ref); + BoundStatement Bind(DelimGetRef &ref); + BoundStatement Bind(ExpressionListRef &ref); + BoundStatement Bind(ColumnDataRef &ref); + BoundStatement Bind(PivotRef &expr); + BoundStatement Bind(ShowRef &ref); unique_ptr BindPivot(PivotRef &expr, vector> all_columns); unique_ptr BindUnpivot(Binder &child_binder, PivotRef &expr, vector> all_columns, unique_ptr &where_clause); - unique_ptr BindBoundPivot(PivotRef &expr); + BoundStatement BindBoundPivot(PivotRef &expr); void ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector &unpivot_entries); void ExtractUnpivotColumnName(ParsedExpression &expr, vector &result); @@ -387,26 +445,14 @@ class Binder : public enable_shared_from_this { bool BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, vector> &expressions, vector &arguments, vector ¶meters, named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error); - void BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery); - unique_ptr BindTableFunction(TableFunction &function, vector parameters); - unique_ptr BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, - vector parameters, - named_parameter_map_t named_parameters, - 
vector input_table_types, - vector input_table_names); - - unique_ptr CreatePlan(BoundBaseTableRef &ref); + BoundStatement &subquery, ErrorData &error); + void BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery); + BoundStatement BindTableFunction(TableFunction &function, vector parameters); + BoundStatement BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, vector input_table_names); + unique_ptr CreatePlan(BoundJoinRef &ref); - unique_ptr CreatePlan(BoundSubqueryRef &ref); - unique_ptr CreatePlan(BoundTableFunction &ref); - unique_ptr CreatePlan(BoundEmptyTableRef &ref); - unique_ptr CreatePlan(BoundExpressionListRef &ref); - unique_ptr CreatePlan(BoundColumnDataRef &ref); - unique_ptr CreatePlan(BoundCTERef &ref); - unique_ptr CreatePlan(BoundPivotRef &ref); - unique_ptr CreatePlan(BoundDelimGetRef &ref); BoundStatement BindCopyTo(CopyStatement &stmt, const CopyFunction &function, CopyToType copy_to_type); BoundStatement BindCopyFrom(CopyStatement &stmt, const CopyFunction &function); @@ -426,12 +472,12 @@ class Binder : public enable_shared_from_this { void PlanSubqueries(unique_ptr &expr, unique_ptr &root); unique_ptr PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root); unique_ptr PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated_columns, + CorrelatedColumns &correlated_columns, JoinType join_type = JoinType::INNER, unique_ptr condition = nullptr); - unique_ptr CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, + unique_ptr CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op); BindingAlias FindBinding(const string &using_column, const string &join_side); @@ -441,8 +487,6 @@ class Binder : public enable_shared_from_this { BindingAlias RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, const string &column_name, const string &join_side); - void AddCTEMap(CommonTableExpressionMap &cte_map); - void ExpandStarExpressions(vector> &select_list, vector> &new_select_list); void ExpandStarExpression(unique_ptr expr, vector> &new_select_list); @@ -463,14 +507,14 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); - unique_ptr BindSelectNode(SelectNode &statement, unique_ptr from_table); + BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); - unique_ptr BindShowQuery(ShowRef &ref); - unique_ptr BindShowTable(ShowRef &ref); - unique_ptr BindSummarize(ShowRef &ref); + BoundStatement BindShowQuery(ShowRef &ref); + BoundStatement BindShowTable(ShowRef &ref); + BoundStatement BindSummarize(ShowRef &ref); void BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, vector &named_column_map, vector &expected_types, @@ -491,6 +535,9 @@ class Binder : public enable_shared_from_this { static void CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool columns_provided, const string &tname); + BoundCTEData PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement); + BoundStatement FinishCTE(BoundCTEData &bound_cte, BoundStatement child_data); + private: 
Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); }; diff --git a/src/duckdb/src/include/duckdb/planner/binding_alias.hpp b/src/duckdb/src/include/duckdb/planner/binding_alias.hpp index 2d85b521e..9d75738bc 100644 --- a/src/duckdb/src/include/duckdb/planner/binding_alias.hpp +++ b/src/duckdb/src/include/duckdb/planner/binding_alias.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb/common/common.hpp" #include "duckdb/common/case_insensitive_map.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp index cd5a78b6a..76c461e78 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp @@ -17,13 +17,8 @@ namespace duckdb { //! Bound equivalent of QueryNode class BoundQueryNode { public: - explicit BoundQueryNode(QueryNodeType type) : type(type) { - } - virtual ~BoundQueryNode() { - } + virtual ~BoundQueryNode() = default; - //! The type of the query node, either SetOperation or Select - QueryNodeType type; //! The result modifiers that should be applied to this query node vector> modifiers; @@ -34,23 +29,6 @@ class BoundQueryNode { public: virtual idx_t GetRootIndex() = 0; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp b/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp index 853384e0b..5f26dd8f7 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_result_modifier.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/limits.hpp" +#include "duckdb/parser/group_by_node.hpp" #include "duckdb/parser/result_modifier.hpp" #include "duckdb/planner/bound_statement.hpp" #include "duckdb/planner/expression.hpp" @@ -155,8 +156,9 @@ class BoundOrderModifier : public BoundResultModifier { //! Remove unneeded/duplicate order elements. //! Returns true of orders is not empty. 
- static bool Simplify(vector &orders, const vector> &groups); - bool Simplify(const vector> &groups); + static bool Simplify(vector &orders, const vector> &groups, + optional_ptr> grouping_sets); + bool Simplify(const vector> &groups, optional_ptr> grouping_sets); }; enum class DistinctType : uint8_t { DISTINCT = 0, DISTINCT_ON = 1 }; diff --git a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp index bb1f7bfec..23fae54d6 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp @@ -9,17 +9,31 @@ #pragma once #include "duckdb/common/string.hpp" +#include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { class LogicalOperator; struct LogicalType; +struct BoundStatement; +class ParsedExpression; +class Binder; + +struct ExtraBoundInfo { + SetOperationType setop_type = SetOperationType::NONE; + vector> child_binders; + vector bound_children; + vector> original_expressions; +}; struct BoundStatement { unique_ptr plan; vector types; vector names; + ExtraBoundInfo extra_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp b/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp deleted file mode 100644 index 0a831c54a..000000000 --- a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/bound_tableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/enums/tableref_type.hpp" -#include "duckdb/parser/parsed_data/sample_options.hpp" - -namespace duckdb { - -class BoundTableRef { -public: - explicit BoundTableRef(TableReferenceType type) : type(type) { - } - virtual ~BoundTableRef() { - } - - //! The type of table reference - TableReferenceType type; - //! 
The sample options (if any) - unique_ptr sample; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp index bd75aac19..862ef5a11 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -16,8 +16,6 @@ namespace duckdb { class BoundQueryNode; class BoundSelectNode; class BoundSetOperationNode; -class BoundRecursiveCTENode; -class BoundCTENode; //===--------------------------------------------------------------------===// // Expressions @@ -45,18 +43,7 @@ class BoundWindowExpression; //===--------------------------------------------------------------------===// // TableRefs //===--------------------------------------------------------------------===// -class BoundTableRef; - -class BoundBaseTableRef; class BoundJoinRef; -class BoundSubqueryRef; -class BoundTableFunction; -class BoundEmptyTableRef; -class BoundExpressionListRef; -class BoundColumnDataRef; -class BoundCTERef; -class BoundPivotRef; - class BoundMergeIntoAction; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp index 58d136372..9b729ebe5 100644 --- a/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp +++ b/src/duckdb/src/include/duckdb/planner/constraints/bound_unique_constraint.hpp @@ -22,7 +22,6 @@ class BoundUniqueConstraint : public BoundConstraint { BoundUniqueConstraint(vector keys_p, physical_index_set_t key_set_p, const bool is_primary_key) : BoundConstraint(ConstraintType::UNIQUE), keys(std::move(keys_p)), key_set(std::move(key_set_p)), is_primary_key(is_primary_key) { - #ifdef DEBUG D_ASSERT(keys.size() == key_set.size()); for (auto &key : keys) { diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp index aa07a67b9..35792c8d4 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp @@ -29,7 +29,7 @@ class BoundSubqueryExpression : public Expression { //! The binder used to bind the subquery node shared_ptr binder; //! The bound subquery node - unique_ptr subquery; + BoundStatement subquery; //! The subquery type SubqueryType subquery_type; //! 
the child expressions to compare with (in case of IN, ANY, ALL operators) diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp index b2712f3bf..3232c4d99 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp @@ -10,9 +10,8 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/stack_checker.hpp" -#include "duckdb/common/exception/binder_exception.hpp" #include "duckdb/common/error_data.hpp" -#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/exception/binder_exception.hpp" #include "duckdb/parser/expression/bound_expression.hpp" #include "duckdb/parser/expression/lambdaref_expression.hpp" #include "duckdb/parser/parsed_expression.hpp" @@ -74,6 +73,18 @@ class ExpressionBinder { ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false); virtual ~ExpressionBinder(); + virtual bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, + BindResult &result, unique_ptr &expr_ptr) { + return false; + } + + virtual bool DoesColumnAliasExist(const ColumnRefExpression &colref) { + return false; + } + + // Returns true if the ColumnRef could be an alias reference (unqualified or qualified with table name "alias") + static bool IsPotentialAlias(const ColumnRefExpression &colref); + //! The target type that should result from the binder. If the result is not of this type, a cast to this type will //! be added. Defaults to INVALID. LogicalType target_type; @@ -133,9 +144,6 @@ class ExpressionBinder { static bool ContainsType(const LogicalType &type, LogicalTypeId target); static LogicalType ExchangeType(const LogicalType &type, LogicalTypeId target, LogicalType new_type); - virtual bool TryBindAlias(ColumnRefExpression &colref, bool root_expression, BindResult &result); - virtual bool QualifyColumnAlias(const ColumnRefExpression &colref); - //! Bind the given expression. Unlike Bind(), this does *not* mute the given ParsedExpression. //! Exposed to be used from sub-binders that aren't subclasses of ExpressionBinder. virtual BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, @@ -150,6 +158,9 @@ class ExpressionBinder { static LogicalType GetExpressionReturnType(const Expression &expr); + //! Returns true if the function name is an alias for the UNNEST function + static bool IsUnnestFunction(const string &function_name); + private: //! Current stack depth idx_t stack_depth = DConstants::INVALID_INDEX; @@ -162,7 +173,8 @@ class ExpressionBinder { BindResult BindExpression(CaseExpression &expr, idx_t depth); BindResult BindExpression(CollateExpression &expr, idx_t depth); BindResult BindExpression(CastExpression &expr, idx_t depth); - BindResult BindExpression(ColumnRefExpression &expr, idx_t depth, bool root_expression); + BindResult BindExpression(ColumnRefExpression &expr, idx_t depth, bool root_expression, + unique_ptr &expr_ptr); BindResult BindExpression(LambdaRefExpression &expr, idx_t depth); BindResult BindExpression(ComparisonExpression &expr, idx_t depth); BindResult BindExpression(ConjunctionExpression &expr, idx_t depth); @@ -216,8 +228,6 @@ class ExpressionBinder { optional_ptr stored_binder; vector bound_columns; - //! 
Returns true if the function name is an alias for the UNNEST function - static bool IsUnnestFunction(const string &function_name); BindResult TryBindLambdaOrJson(FunctionExpression &function, idx_t depth, CatalogEntry &func, const LambdaSyntaxType syntax_type); diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp index d70e65c7d..5e068ce76 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp @@ -24,7 +24,7 @@ class ColumnAliasBinder { bool BindAlias(ExpressionBinder &enclosing_binder, unique_ptr &expr_ptr, idx_t depth, bool root_expression, BindResult &result); // Check if the column reference is an SELECT item alias. - bool QualifyColumnAlias(const ColumnRefExpression &colref); + bool DoesColumnAliasExist(const ColumnRefExpression &colref); private: SelectBindState &bind_state; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp index 2fab45e30..d9a186944 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp @@ -33,9 +33,11 @@ class GroupBinder : public ExpressionBinder { string UnsupportedAggregateMessage() override; BindResult BindSelectRef(idx_t entry); - BindResult BindColumnRef(ColumnRefExpression &expr); + BindResult BindColumnRef(ColumnRefExpression &expr, unique_ptr &expr_ptr); BindResult BindConstant(ConstantExpression &expr); - bool TryBindAlias(ColumnRefExpression &colref, bool root_expression, BindResult &result) override; + + bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, + unique_ptr &expr_ptr) override; SelectNode &node; SelectBindState &bind_state; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp index eb68a0cdf..55f046cd7 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp @@ -24,7 +24,7 @@ class LateralBinder : public ExpressionBinder { return !correlated_columns.empty(); } - static void ReduceExpressionDepth(LogicalOperator &op, const vector &info); + static void ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &info); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, @@ -37,7 +37,7 @@ class LateralBinder : public ExpressionBinder { void ExtractCorrelatedColumns(Expression &expr); private: - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp index d11d94731..956c66fab 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include 
"duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp index d191cd3ef..c070a3325 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp @@ -17,12 +17,14 @@ class SelectBinder : public BaseSelectBinder { public: SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); + bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, + unique_ptr &expr_ptr) override; + bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; + protected: void ThrowIfUnnestInLambda(const ColumnBinding &column_binding) override; BindResult BindUnnest(FunctionExpression &function, idx_t depth, bool root_expression) override; - BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) override; - bool QualifyColumnAlias(const ColumnRefExpression &colref) override; unique_ptr GetSQLValueFunction(const string &column_name) override; protected: diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp new file mode 100644 index 000000000..95ecea934 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/expression_binder/try_operator_binder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/expression_binder.hpp" + +namespace duckdb { + +//! 
This binder is used for the TRY expression +class TryOperatorBinder : public ExpressionBinder { + friend class SelectBinder; + +public: + TryOperatorBinder(Binder &binder, ClientContext &context); + +protected: + BindResult BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, idx_t depth) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp index 20c443120..f6e8db403 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp @@ -24,7 +24,11 @@ class WhereBinder : public ExpressionBinder { bool root_expression = false) override; string UnsupportedAggregateMessage() override; - bool QualifyColumnAlias(const ColumnRefExpression &colref) override; + + bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, + unique_ptr &expr_ptr) override; + + bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; private: BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); diff --git a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp index 5b2e2e8a8..52599a1c8 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp @@ -14,8 +14,6 @@ #include namespace duckdb { -class BoundQueryNode; -class BoundTableRef; class ExpressionIterator { public: @@ -47,18 +45,4 @@ class ExpressionIterator { } }; -class BoundNodeVisitor { -public: - virtual ~BoundNodeVisitor() = default; - - virtual void VisitBoundQueryNode(BoundQueryNode &op); - virtual void VisitBoundTableRef(BoundTableRef &ref); - virtual void VisitExpression(unique_ptr &expression); - -protected: - // The VisitExpressionChildren method is called at the end of every call to VisitExpression to recursively visit all - // expressions in an expression tree. It can be overloaded to prevent automatically visiting the entire tree. 
- virtual void VisitExpressionChildren(Expression &expression); -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp new file mode 100644 index 000000000..bff73c00a --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/bloom_filter +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/planner/table_filter_state.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +class BloomFilter { +public: + BloomFilter() = default; + void Initialize(ClientContext &context_p, idx_t number_of_rows); + + void InsertHashes(const Vector &hashes_v, idx_t count) const; + idx_t LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, idx_t count) const; + + void InsertOne(hash_t hash) const; + bool LookupOne(hash_t hash) const; + + bool IsInitialized() const { + return initialized; + } + +private: + idx_t num_sectors; + uint64_t bitmask; // num_sectors - 1 -> used to get the sector offset + + bool initialized = false; + AllocatedData buf_; + uint64_t *bf; +}; + +class BFTableFilter final : public TableFilter { +private: + BloomFilter &filter; + + bool filters_null_values; + string key_column_name; + LogicalType key_type; + +public: + static constexpr auto TYPE = TableFilterType::BLOOM_FILTER; + +public: + explicit BFTableFilter(BloomFilter &filter_p, const bool filters_null_values_p, const string &key_column_name_p, + const LogicalType &key_type_p) + : TableFilter(TYPE), filter(filter_p), filters_null_values(filters_null_values_p), + key_column_name(key_column_name_p), key_type(key_type_p) { + } + + //! If the join condition is e.g. "A = B", the bf will filter null values. + //! If the condition is "A is B" the filter will let nulls pass + bool FiltersNullValues() const { + return filters_null_values; + } + + LogicalType GetKeyType() const { + return key_type; + } + + string ToString(const string &column_name) const override; + + // Filters by first hashing and then probing the bloom filter. The &sel will hold + // the remaining tuples, &approved_tuple_count will hold the approved count. 
+ idx_t Filter(Vector &keys_v, SelectionVector &sel, idx_t &approved_tuple_count, BFTableFilterState &state) const; + bool FilterValue(const Value &value) const; + + FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; + +private: + static void HashInternal(Vector &keys_v, const SelectionVector &sel, const idx_t approved_count, + BFTableFilterState &state); + + bool Equals(const TableFilter &other) const override; + unique_ptr Copy() const override; + unique_ptr ToExpression(const Expression &column) const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/list.hpp b/src/duckdb/src/include/duckdb/planner/filter/list.hpp index 2b76b4a80..8b0582f31 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/list.hpp @@ -5,3 +5,5 @@ #include "duckdb/planner/filter/null_filter.hpp" #include "duckdb/planner/filter/optional_filter.hpp" #include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp index 089242db8..71ee945d6 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp @@ -10,12 +10,13 @@ #pragma once #include "duckdb/planner/table_filter.hpp" +#include "duckdb/planner/table_filter_state.hpp" namespace duckdb { class OptionalFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::OPTIONAL_FILTER; + static constexpr auto TYPE = TableFilterType::OPTIONAL_FILTER; public: explicit OptionalFilter(unique_ptr filter = nullptr); @@ -30,6 +31,17 @@ class OptionalFilter : public TableFilter { FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + + virtual void FiltersNullValues(const LogicalType &type, bool &filters_nulls, bool &filters_valid_values, + TableFilterState &filter_state) const { + } + + virtual unique_ptr InitializeState(ClientContext &context) const { + return make_uniq(); + } + + virtual idx_t FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, + TableFilterState &filter_state, idx_t scan_count, idx_t &approved_tuple_count) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp new file mode 100644 index 000000000..613844641 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/selectivity_optional_filter +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/filter/optional_filter.hpp" + +namespace duckdb { + +struct SelectivityOptionalFilterState final : public TableFilterState { + enum class FilterStatus { + ACTIVE, + PAUSED_DUE_TO_ZONE_MAP_STATS, // todo: use this to disable the filter 
for one zone map based on CheckStatistics + PAUSED_DUE_TO_HIGH_SELECTIVITY + }; + + struct SelectivityStats { + idx_t tuples_accepted; + idx_t tuples_processed; + idx_t vectors_processed; + + idx_t n_vectors_to_check; + float selectivity_threshold; + + FilterStatus status; + + SelectivityStats(idx_t n_vectors_to_check, float selectivity_threshold); + void Update(idx_t accepted, idx_t processed); + bool IsActive() const; + double GetSelectivity() const; + }; + + unique_ptr child_state; + SelectivityStats stats; + + explicit SelectivityOptionalFilterState(unique_ptr child_state, const idx_t n_vectors_to_check, + const float selectivity_threshold) + : child_state(std::move(child_state)), stats(n_vectors_to_check, selectivity_threshold) { + } +}; + +class SelectivityOptionalFilter final : public OptionalFilter { +public: + static constexpr auto MIN_MAX_THRESHOLD = 0.75f; + static constexpr idx_t MIN_MAX_CHECK_N = 30; + + static constexpr float BF_THRESHOLD = 0.25f; + static constexpr idx_t BF_CHECK_N = 75; + + float selectivity_threshold; + idx_t n_vectors_to_check; + + SelectivityOptionalFilter(unique_ptr filter, float selectivity_threshold, idx_t n_vectors_to_check); + +public: + unique_ptr Copy() const override; + FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + void FiltersNullValues(const LogicalType &type, bool &filters_nulls, bool &filters_valid_values, + TableFilterState &filter_state) const override; + unique_ptr InitializeState(ClientContext &context) const override; + idx_t FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, + TableFilterState &filter_state, idx_t scan_count, idx_t &approved_tuple_count) const override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp index e7f533bdd..743a9153b 100644 --- a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp +++ b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp @@ -45,6 +45,7 @@ class LogicalOperator { public: virtual vector GetColumnBindings(); + virtual idx_t GetRootIndex(); static string ColumnBindingsToString(const vector &bindings); void PrintColumnBindings(); static vector GenerateColumnBindings(idx_t table_idx, idx_t column_count); diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp index 6d27b679e..6b4ee004e 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_column_data_get.hpp @@ -14,6 +14,8 @@ namespace duckdb { +class ManagedResultSet; + //! 
LogicalColumnDataGet represents a scan operation from a ColumnDataCollection class LogicalColumnDataGet : public LogicalOperator { public: diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp index f0d3ca404..35ec90266 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_comparison_join.hpp @@ -52,7 +52,7 @@ class LogicalComparisonJoin : public LogicalJoin { unique_ptr left_child, unique_ptr right_child, unique_ptr condition); - static unique_ptr CreateJoin(ClientContext &context, JoinType type, JoinRefType ref_type, + static unique_ptr CreateJoin(JoinType type, JoinRefType ref_type, unique_ptr left_child, unique_ptr right_child, vector conditions, diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp index cd2ed3c21..0548cd4e7 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp @@ -35,6 +35,6 @@ class LogicalCTE : public LogicalOperator { string ctename; idx_t table_index; idx_t column_count; - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp index 724f2bc57..5e4c83919 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp @@ -27,7 +27,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: explicit LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); explicit LogicalDependentJoin(JoinType type); @@ -35,7 +35,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { //! The conditions of the join unique_ptr join_condition; //! 
The list of columns that have correlations with the right - vector correlated_columns; + CorrelatedColumns correlated_columns; SubqueryType subquery_type = SubqueryType::INVALID; bool perform_delim = true; @@ -51,7 +51,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: static unique_ptr Create(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp deleted file mode 100644 index cbfdecd1f..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class BoundCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; - -public: - BoundCTENode() : BoundQueryNode(QueryNodeType::CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - //! The cte node - unique_ptr query; - //! The child node - unique_ptr child; - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the query side of the CTE - shared_ptr query_binder; - //! The binder used by the child side of the CTE - shared_ptr child_binder; - - CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - -public: - idx_t GetRootIndex() override { - return child->GetRootIndex(); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp deleted file mode 100644 index 3da295e2a..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_recursive_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -//! Bound equivalent of SetOperationNode -class BoundRecursiveCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; - -public: - BoundRecursiveCTENode() : BoundQueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - bool union_all; - //! The left side of the set operation - unique_ptr left; - //! The right side of the set operation - unique_ptr right; - //! Target columns for the recursive key variant - vector> key_targets; - - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the left side of the set operation - shared_ptr left_binder; - //! 
The binder used by the right side of the set operation - shared_ptr right_binder; - -public: - idx_t GetRootIndex() override { - return setop_index; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp index b3a22966a..3fdc186e9 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" @@ -36,18 +35,12 @@ struct BoundUnnestNode { //! Bound equivalent of SelectNode class BoundSelectNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; - -public: - BoundSelectNode() : BoundQueryNode(QueryNodeType::SELECT_NODE) { - } - //! Bind information SelectBindState bind_state; //! The projection list vector> select_list; //! The FROM clause - unique_ptr from_table; + BoundStatement from_table; //! The WHERE clause unique_ptr where_clause; //! list of groups diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 01fa37caf..675007b50 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -14,28 +14,17 @@ namespace duckdb { -struct BoundSetOpChild { - unique_ptr node; - shared_ptr binder; - //! Exprs used by the UNION BY NAME operations to add a new projection - vector> reorder_expressions; -}; - //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; - -public: - BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { - } - //! The type of set operation SetOperationType setop_type = SetOperationType::NONE; //! whether the ALL modifier was used or not bool setop_all = false; //! The bound children - vector bound_children; + vector bound_children; + //! Child binders + vector> child_binders; //! 
Index used by the set operation idx_t setop_index; diff --git a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp index 5c7dbda94..dcac81248 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp @@ -1,4 +1,2 @@ -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp index 2f343e901..5fa37d4ac 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp @@ -18,7 +18,7 @@ namespace duckdb { //! The FlattenDependentJoins class is responsible for pushing the dependent join down into the plan to create a //! flattened subquery struct FlattenDependentJoins { - FlattenDependentJoins(Binder &binder, const vector &correlated, bool perform_delim = true, + FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim = true, bool any_join = false, optional_ptr parent = nullptr); static unique_ptr DecorrelateIndependent(Binder &binder, unique_ptr plan); @@ -33,7 +33,7 @@ struct FlattenDependentJoins { bool parent_is_dependent_join = false); //! Mark entire subtree of Logical Operators as correlated by adding them to the has_correlated_expressions map. - bool MarkSubtreeCorrelated(LogicalOperator &op); + bool MarkSubtreeCorrelated(LogicalOperator &op, idx_t cte_index); //! Push the dependent join down a LogicalOperator unique_ptr PushDownDependentJoin(unique_ptr plan, @@ -47,7 +47,7 @@ struct FlattenDependentJoins { reference_map_t has_correlated_expressions; column_binding_map_t correlated_map; column_binding_map_t replacement_map; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; vector delim_types; bool perform_delim; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp index 6b238ffcc..81a097b49 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp @@ -16,7 +16,7 @@ namespace duckdb { //! 
Helper class to recursively detect correlated expressions inside a single LogicalOperator class HasCorrelatedExpressions : public LogicalOperatorVisitor { public: - explicit HasCorrelatedExpressions(const vector &correlated, bool lateral = false, + explicit HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral = false, idx_t lateral_depth = 0); void VisitOperator(LogicalOperator &op) override; @@ -28,7 +28,7 @@ class HasCorrelatedExpressions : public LogicalOperatorVisitor { unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; // Tracks number of nested laterals idx_t lateral_depth; }; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp index e2c507e73..323f3b9b4 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp @@ -17,13 +17,15 @@ namespace duckdb { //! Helper class to rewrite correlated cte scans within a single LogicalOperator class RewriteCTEScan : public LogicalOperatorVisitor { public: - RewriteCTEScan(idx_t table_index, const vector &correlated_columns); + RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns, + bool rewrite_dependent_joins = false); void VisitOperator(LogicalOperator &op) override; private: idx_t table_index; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; + bool rewrite_dependent_joins = false; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_binding.hpp b/src/duckdb/src/include/duckdb/planner/table_binding.hpp index 9aedc7e70..836f52c41 100644 --- a/src/duckdb/src/include/duckdb/planner/table_binding.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_binding.hpp @@ -17,6 +17,7 @@ #include "duckdb/planner/binding_alias.hpp" #include "duckdb/common/column_index.hpp" #include "duckdb/common/table_column.hpp" +#include "duckdb/planner/bound_statement.hpp" namespace duckdb { class BindContext; @@ -26,30 +27,16 @@ class SubqueryRef; class LogicalGet; class TableCatalogEntry; class TableFunctionCatalogEntry; -class BoundTableFunction; class StandardEntry; struct ColumnBinding; -enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY }; +enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY, CTE }; //! A Binding represents a binding to a table, table-producing function or subquery with a specified table index. struct Binding { Binding(BindingType binding_type, BindingAlias alias, vector types, vector names, idx_t index); virtual ~Binding() = default; - //! The type of Binding - BindingType binding_type; - //! The alias of the binding - BindingAlias alias; - //! The table index of the binding - idx_t index; - //! The types of the bound columns - vector types; - //! Column names of the subquery - vector names; - //! 
Name -> index for the names - case_insensitive_map_t name_map; - public: bool TryGetBindingIndex(const string &column_name, column_t &column_index); column_t GetBindingIndex(const string &column_name); @@ -59,6 +46,14 @@ struct Binding { virtual optional_ptr GetStandardEntry(); string GetAlias() const; + BindingType GetBindingType(); + const BindingAlias &GetBindingAlias(); + idx_t GetIndex(); + const vector &GetColumnTypes(); + const vector &GetColumnNames(); + idx_t GetColumnCount(); + void SetColumnType(idx_t col_idx, LogicalType type); + static BindingAlias GetAlias(const string &explicit_alias, const StandardEntry &entry); static BindingAlias GetAlias(const string &explicit_alias, optional_ptr entry); @@ -78,6 +73,23 @@ struct Binding { } return reinterpret_cast(*this); } + +protected: + void Initialize(); + +protected: + //! The type of Binding + BindingType binding_type; + //! The alias of the binding + BindingAlias alias; + //! The table index of the binding + idx_t index; + //! The types of the bound columns + vector types; + //! Column names of the subquery + vector names; + //! Name -> index for the names + case_insensitive_map_t name_map; }; struct EntryBinding : public Binding { @@ -149,4 +161,44 @@ struct DummyBinding : public Binding { unique_ptr ParamToArg(ColumnRefExpression &col_ref); }; +enum class CTEType { CAN_BE_REFERENCED, CANNOT_BE_REFERENCED }; +struct CTEBinding; + +struct CTEBindState { + CTEBindState(Binder &parent_binder, QueryNode &cte_def, const vector &aliases); + ~CTEBindState(); + + Binder &parent_binder; + QueryNode &cte_def; + const vector &aliases; + idx_t active_binder_count; + shared_ptr query_binder; + BoundStatement query; + vector names; + vector types; + +public: + bool IsBound() const; + void Bind(CTEBinding &binding); +}; + +struct CTEBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::CTE; + +public: + CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, CTEType type); + CTEBinding(BindingAlias alias, shared_ptr bind_state, idx_t index); + +public: + bool CanBeReferenced() const; + bool IsReferenced() const; + void Reference(); + +private: + CTEType cte_type; + idx_t reference_count; + shared_ptr bind_state; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_filter.hpp b/src/duckdb/src/include/duckdb/planner/table_filter.hpp index 297e4e0f4..695947860 100644 --- a/src/duckdb/src/include/duckdb/planner/table_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_filter.hpp @@ -33,7 +33,8 @@ enum class TableFilterType : uint8_t { OPTIONAL_FILTER = 6, // executing filter is not required for query correctness IN_FILTER = 7, // col IN (C1, C2, C3, ...) DYNAMIC_FILTER = 8, // dynamic filters can be updated at run-time - EXPRESSION_FILTER = 9 // an arbitrary expression + EXPRESSION_FILTER = 9, // an arbitrary expression + BLOOM_FILTER = 10, // a probabilistic filter that can test whether a value is in a set of other value }; //! TableFilter represents a filter pushed down into the table scan. 
diff --git a/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp b/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp index 35bb66529..0c0ef1309 100644 --- a/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp @@ -52,4 +52,17 @@ struct ExpressionFilterState : public TableFilterState { ExpressionExecutor executor; }; +struct BFTableFilterState final : public TableFilterState { + idx_t current_capacity; + Vector hashes_v; + Vector found_v; + Vector keys_sliced_v; + SelectionVector bf_sel; + + explicit BFTableFilterState(const LogicalType &key_logical_type) + : current_capacity(STANDARD_VECTOR_SIZE), hashes_v(LogicalType::HASH), found_v(LogicalType::UBIGINT), + keys_sliced_v(key_logical_type), bf_sel(STANDARD_VECTOR_SIZE) { + } +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp deleted file mode 100644 index b1f7f6f46..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_basetableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" - -namespace duckdb { -class TableCatalogEntry; - -//! Represents a TableReference to a base table in the schema -class BoundBaseTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::BASE_TABLE; - -public: - BoundBaseTableRef(TableCatalogEntry &table, unique_ptr get) - : BoundTableRef(TableReferenceType::BASE_TABLE), table(table), get(std::move(get)) { - } - - TableCatalogEntry &table; - unique_ptr get; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp deleted file mode 100644 index 025bc4712..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_column_data_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/optionally_owned_ptr.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundColumnDataRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::COLUMN_DATA; - -public: - explicit BoundColumnDataRef(optionally_owned_ptr collection) - : BoundTableRef(TableReferenceType::COLUMN_DATA), collection(std::move(collection)) { - } - //! The (optionally owned) materialized column data to scan - optionally_owned_ptr collection; - //! 
The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp deleted file mode 100644 index 781402fbe..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_cteref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/enums/cte_materialize.hpp" - -namespace duckdb { - -class BoundCTERef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::CTE; - -public: - BoundCTERef(idx_t bind_index, idx_t cte_index) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index) { - } - - BoundCTERef(idx_t bind_index, idx_t cte_index, bool is_recurring) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index), - is_recurring(is_recurring) { - } - //! The set of columns bound to this base table reference - vector bound_columns; - //! The types of the values list - vector types; - //! The index in the bind context - idx_t bind_index; - //! The index of the cte - idx_t cte_index; - //! Is this a reference to the recurring table of a CTE - bool is_recurring = false; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp deleted file mode 100644 index 7b1022482..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_delimgetref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -class BoundDelimGetRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::DELIM_GET; - -public: - BoundDelimGetRef(idx_t bind_index, const vector &column_types_p) - : BoundTableRef(TableReferenceType::DELIM_GET), bind_index(bind_index), column_types(column_types_p) { - } - idx_t bind_index; - vector column_types; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp deleted file mode 100644 index 3a68f5166..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_dummytableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a cross product -class BoundEmptyTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EMPTY_FROM; - -public: - explicit BoundEmptyTableRef(idx_t bind_index) - : BoundTableRef(TableReferenceType::EMPTY_FROM), bind_index(bind_index) { - } - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp deleted file mode 100644 index 7fc563dda..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_expressionlistref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/expression.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundExpressionListRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EXPRESSION_LIST; - -public: - BoundExpressionListRef() : BoundTableRef(TableReferenceType::EXPRESSION_LIST) { - } - - //! The bound VALUES list - vector>> values; - //! The generated names of the values list - vector names; - //! The types of the values list - vector types; - //! The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp index 38c83c95f..87976ba30 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp @@ -11,19 +11,14 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/joinref_type.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" namespace duckdb { //! Represents a join -class BoundJoinRef : public BoundTableRef { +class BoundJoinRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::JOIN; - -public: - explicit BoundJoinRef(JoinRefType ref_type) - : BoundTableRef(TableReferenceType::JOIN), type(JoinType::INNER), ref_type(ref_type), lateral(false) { + explicit BoundJoinRef(JoinRefType ref_type) : type(JoinType::INNER), ref_type(ref_type), lateral(false) { } //! The binder used to bind the LHS of the join @@ -31,9 +26,9 @@ class BoundJoinRef : public BoundTableRef { //! The binder used to bind the RHS of the join shared_ptr right_binder; //! The left hand side of the join - unique_ptr left; + BoundStatement left; //! The right hand side of the join - unique_ptr right; + BoundStatement right; //! The join condition unique_ptr condition; //! Duplicate Eliminated Columns (if any) @@ -47,7 +42,7 @@ class BoundJoinRef : public BoundTableRef { //! Whether or not this is a lateral join bool lateral; //! The correlated columns of the right-side with the left-side - vector correlated_columns; + CorrelatedColumns correlated_columns; //! 
The mark index, for mark joins generated by the relational API idx_t mark_index {}; }; diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp index 3219f6307..5a2d68aa1 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/parser/tableref/pivotref.hpp" #include "duckdb/function/aggregate_function.hpp" @@ -30,19 +29,13 @@ struct BoundPivotInfo { static BoundPivotInfo Deserialize(Deserializer &deserializer); }; -class BoundPivotRef : public BoundTableRef { +class BoundPivotRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::PIVOT; - -public: - explicit BoundPivotRef() : BoundTableRef(TableReferenceType::PIVOT) { - } - idx_t bind_index; //! The binder used to bind the child of the pivot shared_ptr child_binder; //! The child node of the pivot - unique_ptr child; + BoundStatement child; //! The bound pivot info BoundPivotInfo bound_pivot; }; diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp deleted file mode 100644 index 4cb057e41..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp +++ /dev/null @@ -1,38 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_pos_join_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! Represents a positional join -class BoundPositionalJoinRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::POSITIONAL_JOIN; - -public: - BoundPositionalJoinRef() : BoundTableRef(TableReferenceType::POSITIONAL_JOIN), lateral(false) { - } - - //! The binder used to bind the LHS of the positional join - shared_ptr left_binder; - //! The binder used to bind the RHS of the positional join - shared_ptr right_binder; - //! The left hand side of the positional join - unique_ptr left; - //! The right hand side of the positional join - unique_ptr right; - //! Whether or not this is a lateral positional join - bool lateral; - //! The correlated columns of the right-side with the left-side - vector correlated_columns; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp deleted file mode 100644 index 2d1061c98..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_subqueryref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a cross product -class BoundSubqueryRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; - -public: - BoundSubqueryRef(shared_ptr binder_p, unique_ptr subquery) - : BoundTableRef(TableReferenceType::SUBQUERY), binder(std::move(binder_p)), subquery(std::move(subquery)) { - } - - //! The binder used to bind the subquery - shared_ptr binder; - //! The bound subquery node (if any) - unique_ptr subquery; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp deleted file mode 100644 index 6aafe2b36..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_table_function.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -//! Represents a reference to a table-producing function call -class BoundTableFunction : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::TABLE_FUNCTION; - -public: - explicit BoundTableFunction(unique_ptr get) - : BoundTableRef(TableReferenceType::TABLE_FUNCTION), get(std::move(get)) { - } - - unique_ptr get; - unique_ptr subquery; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp index 79a00ce62..dbc8394df 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp @@ -1,11 +1,2 @@ -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/parser/tableref/delimgetref.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp index 687e4af9b..6cf150366 100644 --- a/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp +++ b/src/duckdb/src/include/duckdb/storage/arena_allocator.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/allocator.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/types/string.hpp" +#include "duckdb/common/arena_containers/arena_ptr.hpp" namespace duckdb { @@ -84,6 +85,16 @@ class ArenaAllocator { return new (mem) T(std::forward(args)...); } + template + arena_ptr MakePtr(ARGS &&... args) { + return arena_ptr(Make(std::forward(args)...)); + } + + template + unsafe_arena_ptr MakeUnsafePtr(ARGS &&... 
args) { + return unsafe_arena_ptr(Make(std::forward(args)...)); + } + String MakeString(const char *data, const size_t len) { data_ptr_t mem = nullptr; diff --git a/src/duckdb/src/include/duckdb/storage/block.hpp b/src/duckdb/src/include/duckdb/storage/block.hpp index 3aa18a7bc..794fd6ff5 100644 --- a/src/duckdb/src/include/duckdb/storage/block.hpp +++ b/src/duckdb/src/include/duckdb/storage/block.hpp @@ -14,14 +14,15 @@ namespace duckdb { +class BlockAllocator; class Serializer; class Deserializer; class Block : public FileBuffer { public: - Block(Allocator &allocator, const block_id_t id, const idx_t block_size, const idx_t block_header_size); - Block(Allocator &allocator, block_id_t id, uint32_t internal_size, idx_t block_header_size); - Block(Allocator &allocator, const block_id_t id, BlockManager &block_manager); + Block(BlockAllocator &allocator, const block_id_t id, const idx_t block_size, const idx_t block_header_size); + Block(BlockAllocator &allocator, block_id_t id, uint32_t internal_size, idx_t block_header_size); + Block(BlockAllocator &allocator, const block_id_t id, BlockManager &block_manager); Block(FileBuffer &source, block_id_t id, idx_t block_header_size); block_id_t id; @@ -61,6 +62,15 @@ struct MetaBlockPointer { block_id_t GetBlockId() const; uint32_t GetBlockIndex() const; + bool operator==(const MetaBlockPointer &rhs) const { + return block_pointer == rhs.block_pointer && offset == rhs.offset; + } + + friend std::ostream &operator<<(std::ostream &os, const MetaBlockPointer &obj) { + return os << "{block_id: " << obj.GetBlockId() << " index: " << obj.GetBlockIndex() << " offset: " << obj.offset + << "}"; + } + void Serialize(Serializer &serializer) const; static MetaBlockPointer Deserialize(Deserializer &source); }; diff --git a/src/duckdb/src/include/duckdb/storage/block_allocator.hpp b/src/duckdb/src/include/duckdb/storage/block_allocator.hpp new file mode 100644 index 000000000..1fe298752 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/block_allocator.hpp @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/block_allocator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/optional_idx.hpp" + +namespace duckdb { + +class Allocator; +class AttachedDatabase; +class DatabaseInstance; +class BlockAllocatorThreadLocalState; +struct BlockQueue; + +class BlockAllocator { + friend class BlockAllocatorThreadLocalState; + +public: + BlockAllocator(Allocator &allocator, idx_t block_size, idx_t virtual_memory_size, idx_t physical_memory_size); + ~BlockAllocator(); + +public: + static BlockAllocator &Get(DatabaseInstance &db); + static BlockAllocator &Get(AttachedDatabase &db); + + //! Resize physical memory (can only be increased) + void Resize(idx_t new_physical_memory_size); + + //! Allocation functions (same API as Allocator) + data_ptr_t AllocateData(idx_t size) const; + void FreeData(data_ptr_t pointer, idx_t size) const; + data_ptr_t ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t new_size) const; + + //! 
Flush outstanding allocations + bool SupportsFlush() const; + void ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) const; + void FlushAll(optional_idx extra_memory = optional_idx()) const; + +private: + bool IsActive() const; + bool IsEnabled() const; + bool IsInPool(data_ptr_t pointer) const; + + idx_t ModuloBlockSize(idx_t n) const; + idx_t DivBlockSize(idx_t n) const; + + uint32_t GetBlockID(data_ptr_t pointer) const; + data_ptr_t GetPointer(uint32_t block_id) const; + + void VerifyBlockID(uint32_t block_id) const; + + void FreeInternal(idx_t extra_memory) const; + void FreeContiguousBlocks(uint32_t block_id_start, uint32_t block_id_end_including) const; + +private: + //! Identifier + const hugeint_t uuid; + //! Fallback allocator + Allocator &allocator; + + //! Block size (power of two) + const idx_t block_size; + //! Shift for dividing by block size + const idx_t block_size_div_shift; + + //! Size of the virtual memory + const idx_t virtual_memory_size; + //! Pointer to the start of the virtual memory + const data_ptr_t virtual_memory_space; + + //! Mutex for modifying physical memory size + mutex physical_memory_lock; + //! Size of the physical memory + atomic physical_memory_size; + + //! Untouched block IDs + unsafe_unique_ptr untouched; + //! Touched by block IDs + unsafe_unique_ptr touched; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/block_manager.hpp b/src/duckdb/src/include/duckdb/storage/block_manager.hpp index 0fd9df675..0351dfb47 100644 --- a/src/duckdb/src/include/duckdb/storage/block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/block_manager.hpp @@ -24,6 +24,8 @@ class ClientContext; class DatabaseInstance; class MetadataManager; +enum class ConvertToPersistentMode { DESTRUCTIVE, THREAD_SAFE }; + //! BlockManager is an abstract representation to manage blocks on DuckDB. When writing or reading blocks, the //! BlockManager creates and accesses blocks. The concrete types implement specific block storage strategies. class BlockManager { @@ -37,16 +39,23 @@ class BlockManager { BufferManager &buffer_manager; public: + BufferManager &GetBufferManager() const { + return buffer_manager; + } //! Creates a new block inside the block manager virtual unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) = 0; virtual unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) = 0; //! Return the next free block id virtual block_id_t GetFreeBlockId() = 0; virtual block_id_t PeekFreeBlockId() = 0; + //! Returns the next free block id and immediately include it in the checkpoint + // Equivalent to calling GetFreeBlockId() followed by MarkBlockAsCheckpointed + virtual block_id_t GetFreeBlockIdForCheckpoint() = 0; + //! Returns whether or not a specified block is the root block virtual bool IsRootBlock(MetaBlockPointer root) = 0; - //! Mark a block as "free"; free blocks are immediately added to the free list and can be immediately overwritten - virtual void MarkBlockAsFree(block_id_t block_id) = 0; + //! Mark a block as included in the next checkpoint + virtual void MarkBlockACheckpointed(block_id_t block_id) = 0; //! Mark a block as "used"; either the block is removed from the free list, or the reference count is incremented virtual void MarkBlockAsUsed(block_id_t block_id) = 0; //! Mark a block as "modified"; modified blocks are added to the free list after a checkpoint (i.e. their data is @@ -95,14 +104,19 @@ class BlockManager { //! 
Register a block with the given block id in the base file shared_ptr RegisterBlock(block_id_t block_id); //! Convert an existing in-memory buffer into a persistent disk-backed block + //! If mode is set to destructive (default) - the old_block will be destroyed as part of this method + //! This can only be safely used when there is no other (lingering) usage of old_block + //! If there is concurrent usage of the block elsewhere - use the THREAD_SAFE mode which creates an extra copy shared_ptr ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block, BufferHandle old_handle); + shared_ptr old_block, BufferHandle old_handle, + ConvertToPersistentMode mode = ConvertToPersistentMode::DESTRUCTIVE); shared_ptr ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block); + shared_ptr old_block, + ConvertToPersistentMode mode = ConvertToPersistentMode::DESTRUCTIVE); void UnregisterBlock(BlockHandle &block); //! UnregisterBlock, only accepts non-temporary block ids - void UnregisterBlock(block_id_t id); + virtual void UnregisterBlock(block_id_t id); //! Returns a reference to the metadata manager of this block manager. MetadataManager &GetMetadataManager(); @@ -151,6 +165,9 @@ class BlockManager { virtual void VerifyBlocks(const unordered_map &block_usage_count) { } +protected: + bool BlockIsRegistered(block_id_t block_id); + public: template TARGET &Cast() { @@ -179,4 +196,11 @@ class BlockManager { //! Default to default_block_header_size for file-backed block managers. optional_idx block_header_size; }; + +struct BlockIdVisitor { + virtual ~BlockIdVisitor() = default; + + virtual void Visit(block_id_t block_id) = 0; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp index 9df6743e2..a77b34ef7 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp @@ -41,7 +41,8 @@ class BufferPool { friend class StandardBufferManager; public: - BufferPool(idx_t maximum_memory, bool track_eviction_timestamps, idx_t allocator_bulk_deallocation_flush_threshold); + BufferPool(BlockAllocator &block_allocator, idx_t maximum_memory, bool track_eviction_timestamps, + idx_t allocator_bulk_deallocation_flush_threshold); virtual ~BufferPool(); //! Set a new memory limit to the buffer pool, throws an exception if the new limit is too low and not enough @@ -160,6 +161,8 @@ class BufferPool { //! and only updates the global counter when the cache value exceeds a threshold. //! Therefore, the statistics may have slight differences from the actual memory usage. mutable MemoryUsage memory_usage; + //! The block allocator + BlockAllocator &block_allocator; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp index 619e89a5a..3d4f5e595 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp @@ -58,6 +58,7 @@ class BufferManager { virtual void ReAllocate(shared_ptr &handle, idx_t block_size) = 0; //! Pin a block handle. virtual BufferHandle Pin(shared_ptr &handle) = 0; + virtual BufferHandle Pin(const QueryContext &context, shared_ptr &handle) = 0; //! Pre-fetch a series of blocks. //! Using this function is a performance suggestion. 
virtual void Prefetch(vector> &handles) = 0; @@ -100,6 +101,8 @@ class BufferManager { //! Set a new swap limit. virtual void SetSwapLimit(optional_idx limit = optional_idx()); + //! Get the block manager used for in-memory data + virtual BlockManager &GetTemporaryBlockManager() = 0; //! Get the temporary file information of each temporary file. virtual vector GetTemporaryFiles(); //! Get the path to the temporary file directory. diff --git a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp index 99e949418..60b4f6b30 100644 --- a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp +++ b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp @@ -54,6 +54,8 @@ struct CachingFileHandle { DUCKDB_API bool CanSeek(); DUCKDB_API bool IsRemoteFile() const; DUCKDB_API bool OnDiskFile(); + DUCKDB_API idx_t SeekPosition(); + DUCKDB_API void Seek(idx_t location); private: //! Get the version tag of the file (for checking cache invalidation) @@ -61,7 +63,8 @@ struct CachingFileHandle { //! Tries to read from the cache, filling "overlapping_ranges" with ranges that overlap with the request. //! Returns an invalid BufferHandle if it fails BufferHandle TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges); + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range); //! Try to read from the specified range, return an invalid BufferHandle if it fails BufferHandle TryReadFromFileRange(const unique_ptr &guard, CachedFileRange &file_range, data_ptr_t &buffer, idx_t nr_bytes, idx_t location); @@ -108,6 +111,7 @@ class CachingFileSystem { friend struct CachingFileHandle; public: + // Notice, the provided [file_system] should be a raw, non-caching filesystem. DUCKDB_API CachingFileSystem(FileSystem &file_system, DatabaseInstance &db); DUCKDB_API ~CachingFileSystem(); diff --git a/src/duckdb/src/include/duckdb/storage/caching_file_system_wrapper.hpp b/src/duckdb/src/include/duckdb/storage/caching_file_system_wrapper.hpp new file mode 100644 index 000000000..d9430bf0f --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/caching_file_system_wrapper.hpp @@ -0,0 +1,145 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/caching_file_system_wrapper.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/storage/caching_file_system.hpp" + +namespace duckdb { + +// Forward declaration. +class DatabaseInstance; +class ClientContext; +class QueryContext; +class CachingFileSystemWrapper; +struct CachingFileHandle; + +//! Caching mode for CachingFileSystemWrapper. +//! By default only remote files will be cached, but it's also allowed to cache local for direct IO use case. +enum class CachingMode : uint8_t { + // Cache all files. + ALWAYS_CACHE = 0, + // Only cache remote files, bypass cache for local files. + CACHE_REMOTE_ONLY = 1, +}; + +//! CachingFileHandleWrapper wraps CachingFileHandle to conform to FileHandle API. 
+class CachingFileHandleWrapper : public FileHandle { + friend class CachingFileSystemWrapper; + +public: + DUCKDB_API CachingFileHandleWrapper(CachingFileSystemWrapper &file_system, unique_ptr handle, + FileOpenFlags flags); + DUCKDB_API ~CachingFileHandleWrapper() override; + + DUCKDB_API void Close() override; + +private: + unique_ptr caching_handle; +}; + +//! [CachingFileSystemWrapper] is an adapter class that wraps [CachingFileSystem] to conform to the FileSystem API. +//! Different from [CachingFileSystem], which owns cache content and returns a [BufferHandle] to achieve zero-copy on +//! read, the wrapper class always copies the requested bytes into the provided address. +//! +//! NOTICE: Currently only read and seek operations are supported; write operations are disabled. +class CachingFileSystemWrapper : public FileSystem { +public: + DUCKDB_API CachingFileSystemWrapper(FileSystem &file_system, DatabaseInstance &db, + CachingMode mode = CachingMode::CACHE_REMOTE_ONLY); + DUCKDB_API ~CachingFileSystemWrapper() override; + + DUCKDB_API static CachingFileSystemWrapper Get(ClientContext &context, + CachingMode mode = CachingMode::CACHE_REMOTE_ONLY); + + DUCKDB_API std::string GetName() const override; + + DUCKDB_API unique_ptr OpenFile(const string &path, FileOpenFlags flags, + optional_ptr opener = nullptr) override; + DUCKDB_API unique_ptr OpenFile(const OpenFileInfo &path, FileOpenFlags flags, + optional_ptr opener = nullptr); + + DUCKDB_API void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + DUCKDB_API void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + DUCKDB_API int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + DUCKDB_API int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + DUCKDB_API bool Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) override; + + DUCKDB_API int64_t GetFileSize(FileHandle &handle) override; + DUCKDB_API timestamp_t GetLastModifiedTime(FileHandle &handle) override; + DUCKDB_API string GetVersionTag(FileHandle &handle) override; + DUCKDB_API FileType GetFileType(FileHandle &handle) override; + DUCKDB_API FileMetadata Stats(FileHandle &handle) override; + DUCKDB_API void Truncate(FileHandle &handle, int64_t new_size) override; + DUCKDB_API void FileSync(FileHandle &handle) override; + + DUCKDB_API bool DirectoryExists(const string &directory, optional_ptr opener = nullptr) override; + DUCKDB_API void CreateDirectory(const string &directory, optional_ptr opener = nullptr) override; + DUCKDB_API void CreateDirectoriesRecursive(const string &path, optional_ptr opener = nullptr) override; + DUCKDB_API void RemoveDirectory(const string &directory, optional_ptr opener = nullptr) override; + + DUCKDB_API bool ListFiles(const string &directory, const std::function &callback, + FileOpener *opener = nullptr) override; + + DUCKDB_API void MoveFile(const string &source, const string &target, + optional_ptr opener = nullptr) override; + DUCKDB_API bool FileExists(const string &filename, optional_ptr opener = nullptr) override; + DUCKDB_API bool IsPipe(const string &filename, optional_ptr opener = nullptr) override; + DUCKDB_API void RemoveFile(const string &filename, optional_ptr opener = nullptr) override; + DUCKDB_API bool TryRemoveFile(const string &filename, optional_ptr opener = nullptr) override; + + DUCKDB_API string GetHomeDirectory() override; + DUCKDB_API string ExpandPath(const string &path) override; + DUCKDB_API string 
PathSeparator(const string &path) override; + + DUCKDB_API vector Glob(const string &path, FileOpener *opener = nullptr) override; + + DUCKDB_API void RegisterSubSystem(unique_ptr sub_fs) override; + DUCKDB_API void RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) override; + DUCKDB_API void UnregisterSubSystem(const string &name) override; + DUCKDB_API unique_ptr ExtractSubSystem(const string &name) override; + DUCKDB_API vector ListSubSystems() override; + DUCKDB_API bool CanHandleFile(const string &fpath) override; + + DUCKDB_API void Seek(FileHandle &handle, idx_t location) override; + DUCKDB_API void Reset(FileHandle &handle) override; + DUCKDB_API idx_t SeekPosition(FileHandle &handle) override; + + DUCKDB_API bool IsManuallySet() override; + DUCKDB_API bool CanSeek() override; + DUCKDB_API bool OnDiskFile(FileHandle &handle) override; + + DUCKDB_API unique_ptr OpenCompressedFile(QueryContext context, unique_ptr handle, + bool write) override; + + DUCKDB_API void SetDisabledFileSystems(const vector &names) override; + DUCKDB_API bool SubSystemIsDisabled(const string &name) override; + DUCKDB_API bool IsDisabledForPath(const string &path) override; + +protected: + DUCKDB_API unique_ptr OpenFileExtended(const OpenFileInfo &path, FileOpenFlags flags, + optional_ptr opener) override; + DUCKDB_API bool SupportsOpenFileExtended() const override; + DUCKDB_API bool ListFilesExtended(const string &directory, const std::function &callback, + optional_ptr opener) override; + DUCKDB_API bool SupportsListFilesExtended() const override; + +private: + bool ShouldUseCache(const string &path) const; + + // Return an optional caching file handle, if certain filepath is cached. + CachingFileHandle *GetCachingHandleIfPossible(FileHandle &handle); + + CachingFileSystem caching_file_system; + FileSystem &underlying_file_system; + CachingMode caching_mode; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/checkpoint_options.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/checkpoint_options.hpp new file mode 100644 index 000000000..2ccc15c70 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/checkpoint_options.hpp @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/checkpoint/checkpoint_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/checkpoint_type.hpp" + +namespace duckdb { + +struct CheckpointOptions { + CheckpointOptions() + : wal_action(CheckpointWALAction::DONT_DELETE_WAL), action(CheckpointAction::CHECKPOINT_IF_REQUIRED), + type(CheckpointType::FULL_CHECKPOINT), transaction_id(MAX_TRANSACTION_ID) { + } + + CheckpointWALAction wal_action; + CheckpointAction action; + CheckpointType type; + transaction_t transaction_id; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp index ba1eea5b2..b1cea22b8 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp @@ -31,7 +31,7 @@ class RowGroupWriter { return compression_types; } - virtual CheckpointType GetCheckpointType() const = 0; + virtual CheckpointOptions GetCheckpointOptions() const = 0; virtual WriteStream &GetPayloadWriter() = 0; virtual MetaBlockPointer 
GetMetaBlockPointer() = 0; virtual optional_ptr GetMetadataManager() = 0; @@ -58,7 +58,7 @@ class SingleFileRowGroupWriter : public RowGroupWriter { TableDataWriter &writer, MetadataWriter &table_data_writer); public: - CheckpointType GetCheckpointType() const override; + CheckpointOptions GetCheckpointOptions() const override; WriteStream &GetPayloadWriter() override; MetaBlockPointer GetMetaBlockPointer() override; optional_ptr GetMetadataManager() override; diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp index 5ff309563..a13c72943 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/string_checkpoint_state.hpp @@ -55,9 +55,6 @@ struct UncompressedStringSegmentState : public CompressedSegmentState { string GetSegmentInfo() const override; - vector GetAdditionalBlocks() const override; - void Cleanup(BlockManager &manager); - private: mutex block_lock; unordered_map> handles; diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp index 30dd141b3..7c960098f 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp @@ -35,8 +35,15 @@ class TableDataWriter { virtual unique_ptr GetRowGroupWriter(RowGroup &row_group) = 0; virtual void AddRowGroup(RowGroupPointer &&row_group_pointer, unique_ptr writer); - virtual CheckpointType GetCheckpointType() const = 0; + virtual CheckpointOptions GetCheckpointOptions() const = 0; + virtual void FlushPartialBlocks() = 0; virtual MetadataManager &GetMetadataManager() = 0; + bool CanOverrideBaseStats() const { + return override_base_stats; + } + void SetCannotOverrideStats() { + override_base_stats = false; + } DatabaseInstance &GetDatabase(); unique_ptr CreateTaskExecutor(); @@ -46,6 +53,7 @@ class TableDataWriter { optional_ptr context; //! Pointers to the start of each row group. 
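Since RowGroupWriter and TableDataWriter now expose GetCheckpointOptions() instead of a bare CheckpointType, here is a small sketch of populating the CheckpointOptions struct added earlier in this diff; the field names and defaults come from that struct, while the helper function and CheckpointWALAction::DELETE_WAL are assumptions for illustration.

// Hypothetical helper building options for a checkpoint that should also truncate the WAL.
CheckpointOptions MakeWalTruncatingOptions(transaction_t transaction_id) {
	CheckpointOptions options;                            // defaults: DONT_DELETE_WAL, CHECKPOINT_IF_REQUIRED, FULL_CHECKPOINT
	options.wal_action = CheckpointWALAction::DELETE_WAL; // assumed counterpart of the DONT_DELETE_WAL default
	options.transaction_id = transaction_id;              // new field, replaces the MAX_TRANSACTION_ID default
	return options;
}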
vector row_group_pointers; + bool override_base_stats = true; }; class SingleFileTableDataWriter : public TableDataWriter { @@ -58,7 +66,8 @@ class SingleFileTableDataWriter : public TableDataWriter { void FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) override; unique_ptr GetRowGroupWriter(RowGroup &row_group) override; - CheckpointType GetCheckpointType() const override; + CheckpointOptions GetCheckpointOptions() const override; + void FlushPartialBlocks() override; MetadataManager &GetMetadataManager() override; private: diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp index 318d93abf..e7b220da1 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp @@ -99,7 +99,7 @@ class SingleFileCheckpointWriter final : public CheckpointWriter { public: SingleFileCheckpointWriter(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, - CheckpointType checkpoint_type); + CheckpointOptions options); void CreateCheckpoint() override; @@ -108,8 +108,8 @@ class SingleFileCheckpointWriter final : public CheckpointWriter { unique_ptr GetTableDataWriter(TableCatalogEntry &table) override; BlockManager &GetBlockManager(); - CheckpointType GetCheckpointType() const { - return checkpoint_type; + CheckpointOptions GetCheckpointOptions() const { + return options; } optional_ptr GetClientContext() const { return context; @@ -128,7 +128,7 @@ class SingleFileCheckpointWriter final : public CheckpointWriter { //! an entire checkpoint. PartialBlockManager partial_block_manager; //! Checkpoint type - CheckpointType checkpoint_type; + CheckpointOptions options; //! 
Block usage count for verification purposes unordered_map verify_block_usage_count; }; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp index 68c8c28a9..2fca804ad 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp @@ -60,9 +60,9 @@ struct AlpCombination { }; template -class AlpCompressionState { +class AlpCompressionData { public: - AlpCompressionState() : vector_encoding_indices(0, 0), exceptions_count(0), bit_width(0) { + AlpCompressionData() : vector_encoding_indices(0, 0), exceptions_count(0), bit_width(0) { } void Reset() { @@ -86,11 +86,21 @@ class AlpCompressionState { uint16_t exceptions_positions[AlpConstants::ALP_VECTOR_SIZE]; vector best_k_combinations; uint8_t values_encoded[AlpConstants::ALP_VECTOR_SIZE * 8]; + using EXACT_TYPE = typename FloatingToExact::TYPE; + + idx_t RequiredSpace() const { + idx_t required_space = + bp_size + (exceptions_count * (sizeof(EXACT_TYPE) + AlpConstants::EXCEPTION_POSITION_SIZE)) + + AlpConstants::EXPONENT_SIZE + AlpConstants::FACTOR_SIZE + AlpConstants::EXCEPTIONS_COUNT_SIZE + + AlpConstants::FOR_SIZE + AlpConstants::BIT_WIDTH_SIZE; + + return required_space; + } }; template struct AlpCompression { - using State = AlpCompressionState; + using CompressionData = AlpCompressionData; static constexpr uint8_t EXACT_TYPE_BITSIZE = sizeof(T) * 8; /* @@ -198,8 +208,8 @@ struct AlpCompression { * This function is called once per segment * This operates over ALP first level samples */ - static void FindTopKCombinations(const vector> &vectors_sampled, State &state) { - state.ResetCombinations(); + static void FindTopKCombinations(const vector> &vectors_sampled, CompressionData &compression_data) { + compression_data.ResetCombinations(); unordered_map best_k_combinations_hash; @@ -244,7 +254,7 @@ struct AlpCompression { // Save k' best combinations for (idx_t i = 0; i < MinValue(AlpConstants::MAX_COMBINATIONS, (uint8_t)best_k_combinations.size()); i++) { - state.best_k_combinations.push_back(best_k_combinations[i]); + compression_data.best_k_combinations.push_back(best_k_combinations[i]); } } @@ -252,7 +262,7 @@ struct AlpCompression { * Find the best combination of factor-exponent for a vector from within the best k combinations * This is ALP second level sampling */ - static void FindBestFactorAndExponent(const T *input_vector, idx_t n_values, State &state) { + static void FindBestFactorAndExponent(const T *input_vector, idx_t n_values, CompressionData &compression_data) { //! We sample equidistant values within a vector; to do this we skip a fixed number of values vector vector_sample; auto idx_increments = MaxValue( @@ -266,7 +276,7 @@ struct AlpCompression { idx_t worse_total_bits_counter = 0; //! 
We try each K combination in search for the one which minimize the compression size in the vector - for (auto &combination : state.best_k_combinations) { + for (auto &combination : compression_data.best_k_combinations) { uint64_t estimated_compression_size = DryCompressToEstimateSize(vector_sample, combination.encoding_indices); @@ -284,18 +294,18 @@ struct AlpCompression { best_encoding_indices = combination.encoding_indices; worse_total_bits_counter = 0; } - state.vector_encoding_indices = best_encoding_indices; + compression_data.vector_encoding_indices = best_encoding_indices; } /* * ALP Compress */ static void Compress(const T *input_vector, idx_t n_values, const uint16_t *vector_null_positions, - idx_t nulls_count, State &state) { - if (state.best_k_combinations.size() > 1) { - FindBestFactorAndExponent(input_vector, n_values, state); + idx_t nulls_count, CompressionData &compression_data) { + if (compression_data.best_k_combinations.size() > 1) { + FindBestFactorAndExponent(input_vector, n_values, compression_data); } else { - state.vector_encoding_indices = state.best_k_combinations[0].encoding_indices; + compression_data.vector_encoding_indices = compression_data.best_k_combinations[0].encoding_indices; } // Encoding Floating-Point to Int64 @@ -303,48 +313,48 @@ struct AlpCompression { uint16_t exceptions_idx = 0; for (idx_t i = 0; i < n_values; i++) { T actual_value = input_vector[i]; - int64_t encoded_value = EncodeValue(actual_value, state.vector_encoding_indices); - T decoded_value = DecodeValue(encoded_value, state.vector_encoding_indices); - state.encoded_integers[i] = encoded_value; + int64_t encoded_value = EncodeValue(actual_value, compression_data.vector_encoding_indices); + T decoded_value = DecodeValue(encoded_value, compression_data.vector_encoding_indices); + compression_data.encoded_integers[i] = encoded_value; //! 
We detect exceptions using a predicated comparison auto is_exception = (decoded_value != actual_value); - state.exceptions_positions[exceptions_idx] = UnsafeNumericCast(i); + compression_data.exceptions_positions[exceptions_idx] = UnsafeNumericCast(i); exceptions_idx += is_exception; } // Finding first non exception value int64_t a_non_exception_value = 0; for (idx_t i = 0; i < n_values; i++) { - if (i != state.exceptions_positions[i]) { - a_non_exception_value = state.encoded_integers[i]; + if (i != compression_data.exceptions_positions[i]) { + a_non_exception_value = compression_data.encoded_integers[i]; break; } } // Replacing that first non exception value on the vector exceptions for (idx_t i = 0; i < exceptions_idx; i++) { - idx_t exception_pos = state.exceptions_positions[i]; + idx_t exception_pos = compression_data.exceptions_positions[i]; T actual_value = input_vector[exception_pos]; - state.encoded_integers[exception_pos] = a_non_exception_value; - state.exceptions[i] = actual_value; + compression_data.encoded_integers[exception_pos] = a_non_exception_value; + compression_data.exceptions[i] = actual_value; } - state.exceptions_count = exceptions_idx; + compression_data.exceptions_count = exceptions_idx; // Replacing nulls with that first non exception value for (idx_t i = 0; i < nulls_count; i++) { uint16_t null_value_pos = vector_null_positions[i]; - state.encoded_integers[null_value_pos] = a_non_exception_value; + compression_data.encoded_integers[null_value_pos] = a_non_exception_value; } // Analyze FFOR auto min_value = NumericLimits::Maximum(); auto max_value = NumericLimits::Minimum(); for (idx_t i = 0; i < n_values; i++) { - max_value = MaxValue(max_value, state.encoded_integers[i]); - min_value = MinValue(min_value, state.encoded_integers[i]); + max_value = MaxValue(max_value, compression_data.encoded_integers[i]); + min_value = MinValue(min_value, compression_data.encoded_integers[i]); } uint64_t min_max_diff = (static_cast(max_value) - static_cast(min_value)); - auto *u_encoded_integers = reinterpret_cast(state.encoded_integers); + auto *u_encoded_integers = reinterpret_cast(compression_data.encoded_integers); auto const u_min_value = static_cast(min_value); // Subtract FOR @@ -357,19 +367,19 @@ struct AlpCompression { auto bit_width = BitpackingPrimitives::MinimumBitWidth(min_max_diff); auto bp_size = BitpackingPrimitives::GetRequiredSize(n_values, bit_width); if (!EMPTY && bit_width > 0) { //! 
We only execute the BP if we are writing the data - BitpackingPrimitives::PackBuffer(state.values_encoded, u_encoded_integers, n_values, - bit_width); + BitpackingPrimitives::PackBuffer(compression_data.values_encoded, u_encoded_integers, + n_values, bit_width); } - state.bit_width = bit_width; // in bits - state.bp_size = bp_size; // in bytes - state.frame_of_reference = static_cast(min_value); // understood this can be negative + compression_data.bit_width = bit_width; // in bits + compression_data.bp_size = bp_size; // in bytes + compression_data.frame_of_reference = static_cast(min_value); // understood this can be negative } /* * Overload without specifying nulls */ - static void Compress(const T *input_vector, idx_t n_values, State &state) { - Compress(input_vector, n_values, nullptr, 0, state); + static void Compress(const T *input_vector, idx_t n_values, CompressionData &compression_data) { + Compress(input_vector, n_values, nullptr, 0, compression_data); } }; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp index bac590d0e..015146b62 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/storage_manager.hpp" #include "duckdb/storage/compression/alp/algorithm/alp.hpp" #include "duckdb/storage/compression/alp/alp_constants.hpp" #include "duckdb/storage/compression/alp/alp_utils.hpp" @@ -24,7 +25,7 @@ struct AlpAnalyzeState : public AnalyzeState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; - explicit AlpAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), state() { + explicit AlpAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), compression_data() { } idx_t total_bytes_used = 0; @@ -34,7 +35,8 @@ struct AlpAnalyzeState : public AnalyzeState { idx_t vectors_count = 0; vector> rowgroup_sample; vector> complete_vectors_sampled; - alp::AlpCompressionState state; + alp::AlpCompressionData compression_data; + idx_t storage_version = 0; public: // Returns the required space to hyphotetically store the compressed segment @@ -44,23 +46,14 @@ struct AlpAnalyzeState : public AnalyzeState { current_bytes_used_in_segment = 0; } - // Returns the required space to hyphotetically store the compressed vector - idx_t RequiredSpace() const { - idx_t required_space = - state.bp_size + state.exceptions_count * (sizeof(EXACT_TYPE) + AlpConstants::EXCEPTION_POSITION_SIZE) + - AlpConstants::EXPONENT_SIZE + AlpConstants::FACTOR_SIZE + AlpConstants::EXCEPTIONS_COUNT_SIZE + - AlpConstants::FOR_SIZE + AlpConstants::BIT_WIDTH_SIZE + AlpConstants::METADATA_POINTER_SIZE; - return required_space; - } - - void FlushVector() { - current_bytes_used_in_segment += RequiredSpace(); - state.Reset(); + void FlushVector(idx_t vector_size) { + current_bytes_used_in_segment += vector_size; + compression_data.Reset(); } // Check if we have enough space in the segment to hyphotetically store the compressed vector - bool HasEnoughSpace() { - idx_t bytes_to_be_used = AlignValue(current_bytes_used_in_segment + RequiredSpace()); + bool HasEnoughSpace(idx_t vector_size) { + idx_t bytes_to_be_used = AlignValue(current_bytes_used_in_segment + vector_size); // We have enough space if the already used space + the required space for a new vector // does not exceed the space of the block 
- the segment header (the pointer to the metadata) return bytes_to_be_used <= (info.GetBlockSize() - AlpConstants::METADATA_POINTER_SIZE); @@ -74,7 +67,9 @@ struct AlpAnalyzeState : public AnalyzeState { template unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) { CompressionInfo info(col_data.GetBlockManager()); - return make_uniq>(info); + auto state = make_uniq>(info); + state->storage_version = col_data.GetStorageManager().GetStorageVersion(); + return unique_ptr(std::move(state)); } /* @@ -82,7 +77,12 @@ unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) */ template bool AlpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = (AlpAnalyzeState &)state; + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + + auto &analyze_state = state.Cast>(); + bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); analyze_state.vectors_count += 1; @@ -149,17 +149,24 @@ idx_t AlpFinalAnalyze(AnalyzeState &state) { auto &analyze_state = (AlpAnalyzeState &)state; // Finding the Top K combinations of Exponent and Factor - alp::AlpCompression::FindTopKCombinations(analyze_state.rowgroup_sample, analyze_state.state); + alp::AlpCompression::FindTopKCombinations(analyze_state.rowgroup_sample, analyze_state.compression_data); // Encode the entire sampled vectors to estimate a compression size idx_t compressed_values = 0; for (auto &vector_to_compress : analyze_state.complete_vectors_sampled) { alp::AlpCompression::Compress(vector_to_compress.data(), vector_to_compress.size(), - analyze_state.state); - if (!analyze_state.HasEnoughSpace()) { + analyze_state.compression_data); + const idx_t uncompressed_size = AlpConstants::EXPONENT_SIZE + sizeof(T) * vector_to_compress.size(); + const idx_t compressed_size = analyze_state.compression_data.RequiredSpace(); + const bool should_compress = compressed_size < uncompressed_size || analyze_state.storage_version < 7; + + const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; + + if (!analyze_state.HasEnoughSpace(vector_size)) { analyze_state.FlushSegment(); } - analyze_state.FlushVector(); + analyze_state.FlushVector(vector_size); + compressed_values += vector_to_compress.size(); } diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp index e08b8b5bb..9fd341331 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp @@ -28,17 +28,16 @@ namespace duckdb { template struct AlpCompressionState : public CompressionState { - public: using EXACT_TYPE = typename FloatingToExact::TYPE; AlpCompressionState(ColumnDataCheckpointData &checkpoint_data, AlpAnalyzeState *analyze_state) : CompressionState(analyze_state->info), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ALP)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); //! 
Combinations found on the analyze step are needed for compression - state.best_k_combinations = analyze_state->state.best_k_combinations; + compression_data.best_k_combinations = analyze_state->compression_data.best_k_combinations; } ColumnDataCheckpointData &checkpoint_data; @@ -55,10 +54,10 @@ struct AlpCompressionState : public CompressionState { data_ptr_t metadata_ptr; // Reverse pointer to the next free spot for the metadata; used in decoding to SKIP vectors uint32_t next_vector_byte_index_start = AlpConstants::HEADER_SIZE; - T input_vector[AlpConstants::ALP_VECTOR_SIZE]; + T input_vector[AlpConstants::ALP_VECTOR_SIZE]; // Uncompressed data uint16_t vector_null_positions[AlpConstants::ALP_VECTOR_SIZE]; - alp::AlpCompressionState state; + alp::AlpCompressionData compression_data; public: // Returns the space currently used in the segment (in bytes) @@ -66,19 +65,10 @@ struct AlpCompressionState : public CompressionState { return AlpConstants::METADATA_POINTER_SIZE + data_bytes_used; } - // Returns the required space to store the newly compressed vector - idx_t RequiredSpace() { - idx_t required_space = - state.bp_size + (state.exceptions_count * (sizeof(EXACT_TYPE) + AlpConstants::EXCEPTION_POSITION_SIZE)) + - AlpConstants::EXPONENT_SIZE + AlpConstants::FACTOR_SIZE + AlpConstants::EXCEPTIONS_COUNT_SIZE + - AlpConstants::FOR_SIZE + AlpConstants::BIT_WIDTH_SIZE; - return required_space; - } - - bool HasEnoughSpace() { + bool HasEnoughSpace(idx_t vector_size) { //! If [start of block + used space + required space] is more than whats left (current position //! of metadata pointer - the size of a new metadata pointer) - if ((handle.Ptr() + AlignValue(UsedSpace() + RequiredSpace())) >= + if ((handle.Ptr() + AlignValue(UsedSpace() + vector_size)) >= (metadata_ptr - AlpConstants::METADATA_POINTER_SIZE)) { return false; } @@ -86,15 +76,15 @@ struct AlpCompressionState : public CompressionState { } void ResetVector() { - state.Reset(); + compression_data.Reset(); } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); @@ -111,58 +101,101 @@ struct AlpCompressionState : public CompressionState { if (nulls_idx) { alp::AlpUtils::FindAndReplaceNullsInVector(input_vector, vector_null_positions, vector_idx, nulls_idx); } - alp::AlpCompression::Compress(input_vector, vector_idx, vector_null_positions, nulls_idx, state); + alp::AlpCompression::Compress(input_vector, vector_idx, vector_null_positions, nulls_idx, + compression_data); + const idx_t uncompressed_size = AlpConstants::EXPONENT_SIZE + sizeof(T) * vector_idx; + const idx_t compressed_size = compression_data.RequiredSpace(); + + const auto storage_version = checkpoint_data.GetStorageManager().GetStorageVersion(); + const bool should_compress = compressed_size < uncompressed_size || storage_version < 7; + + const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; + //! 
Check if the compressed vector fits on current segment - if (!HasEnoughSpace()) { - auto row_start = current_segment->start + current_segment->count; + if (!HasEnoughSpace(vector_size)) { FlushSegment(); - CreateEmptySegment(row_start); + CreateEmptySegment(); } + if (nulls_idx) { + current_segment->stats.statistics.SetHasNullFast(); + } if (vector_idx != nulls_idx) { //! At least there is one valid value in the vector + current_segment->stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < vector_idx; i++) { current_segment->stats.statistics.UpdateNumericStats(input_vector[i]); } } current_segment->count += vector_idx; - FlushVector(); + + if (should_compress) { + FlushCompressedVector(); + } else { + FlushUncompressedVector(); + } } // Stores the vector and its metadata - void FlushVector() { - Store(state.vector_encoding_indices.exponent, data_ptr); + void FlushCompressedVector() { + Store(compression_data.vector_encoding_indices.exponent, data_ptr); data_ptr += AlpConstants::EXPONENT_SIZE; - Store(state.vector_encoding_indices.factor, data_ptr); + Store(compression_data.vector_encoding_indices.factor, data_ptr); data_ptr += AlpConstants::FACTOR_SIZE; - Store(state.exceptions_count, data_ptr); + Store(compression_data.exceptions_count, data_ptr); data_ptr += AlpConstants::EXCEPTIONS_COUNT_SIZE; - Store(state.frame_of_reference, data_ptr); + Store(compression_data.frame_of_reference, data_ptr); data_ptr += AlpConstants::FOR_SIZE; - Store(UnsafeNumericCast(state.bit_width), data_ptr); + Store(UnsafeNumericCast(compression_data.bit_width), data_ptr); data_ptr += AlpConstants::BIT_WIDTH_SIZE; - memcpy((void *)data_ptr, (void *)state.values_encoded, state.bp_size); + memcpy((void *)data_ptr, (void *)compression_data.values_encoded, compression_data.bp_size); // We should never go out of bounds in the values_encoded array - D_ASSERT((AlpConstants::ALP_VECTOR_SIZE * 8) >= state.bp_size); + D_ASSERT((AlpConstants::ALP_VECTOR_SIZE * 8) >= compression_data.bp_size); - data_ptr += state.bp_size; + data_ptr += compression_data.bp_size; - if (state.exceptions_count > 0) { - memcpy((void *)data_ptr, (void *)state.exceptions, sizeof(EXACT_TYPE) * state.exceptions_count); - data_ptr += sizeof(EXACT_TYPE) * state.exceptions_count; - memcpy((void *)data_ptr, (void *)state.exceptions_positions, - AlpConstants::EXCEPTION_POSITION_SIZE * state.exceptions_count); - data_ptr += AlpConstants::EXCEPTION_POSITION_SIZE * state.exceptions_count; + if (compression_data.exceptions_count > 0) { + memcpy((void *)data_ptr, (void *)compression_data.exceptions, + sizeof(EXACT_TYPE) * compression_data.exceptions_count); + data_ptr += sizeof(EXACT_TYPE) * compression_data.exceptions_count; + memcpy((void *)data_ptr, (void *)compression_data.exceptions_positions, + AlpConstants::EXCEPTION_POSITION_SIZE * compression_data.exceptions_count); + data_ptr += AlpConstants::EXCEPTION_POSITION_SIZE * compression_data.exceptions_count; } - data_bytes_used += state.bp_size + - (state.exceptions_count * (sizeof(EXACT_TYPE) + AlpConstants::EXCEPTION_POSITION_SIZE)) + - AlpConstants::EXPONENT_SIZE + AlpConstants::FACTOR_SIZE + - AlpConstants::EXCEPTIONS_COUNT_SIZE + AlpConstants::FOR_SIZE + AlpConstants::BIT_WIDTH_SIZE; + data_bytes_used += + compression_data.bp_size + + (compression_data.exceptions_count * (sizeof(EXACT_TYPE) + AlpConstants::EXCEPTION_POSITION_SIZE)) + + AlpConstants::EXPONENT_SIZE + AlpConstants::FACTOR_SIZE + AlpConstants::EXCEPTIONS_COUNT_SIZE + + AlpConstants::FOR_SIZE + AlpConstants::BIT_WIDTH_SIZE; + + // 
Write pointer to the vector data (metadata) + metadata_ptr -= sizeof(uint32_t); + Store(next_vector_byte_index_start, metadata_ptr); + next_vector_byte_index_start = NumericCast(UsedSpace()); + + vectors_flushed++; + vector_idx = 0; + nulls_idx = 0; + ResetVector(); + } + + // Uncompressed mode + void FlushUncompressedVector() { + // Store a sentinel value instead of the exponent, signaling the coming data is stored uncompressed. + constexpr uint8_t sentinel = AlpConstants::UNCOMPRESSED_MODE_SENTINEL; + Store(sentinel, data_ptr); + data_ptr += AlpConstants::EXPONENT_SIZE; + + // Store uncompressed data + memcpy(data_ptr, input_vector, sizeof(T) * vector_idx); + data_ptr += sizeof(T) * vector_idx; + + data_bytes_used += AlpConstants::EXPONENT_SIZE + (sizeof(T) * vector_idx); // Write pointer to the vector data (metadata) metadata_ptr -= sizeof(uint32_t); @@ -222,7 +255,8 @@ struct AlpCompressionState : public CompressionState { FlushSegment(); current_segment.reset(); } - + //! Stages uncompressed input values into fixed-size batches (ALP_VECTOR_SIZE), calling CompressVector() to + //! compress and flush each full batch to the segment. Handles nulls and processes arbitrarily large inputs. void Append(UnifiedVectorFormat &vdata, idx_t count) { auto data = UnifiedVectorFormat::GetData(vdata); idx_t values_left_in_data = count; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp index cf9766177..86b399cea 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_constants.hpp @@ -22,7 +22,9 @@ class AlpConstants { static constexpr uint32_t RG_SAMPLES_DUCKDB_JUMP = (DEFAULT_ROW_GROUP_SIZE / RG_SAMPLES) / STANDARD_VECTOR_SIZE; static constexpr uint8_t HEADER_SIZE = sizeof(uint32_t); + //! exponent can store the UNCOMPRESSED_MODE_SENTINEL value static constexpr uint8_t EXPONENT_SIZE = sizeof(uint8_t); + static constexpr uint8_t UNCOMPRESSED_MODE_SENTINEL = std::numeric_limits::max(); static constexpr uint8_t FACTOR_SIZE = sizeof(uint8_t); static constexpr uint8_t EXCEPTIONS_COUNT_SIZE = sizeof(uint16_t); static constexpr uint8_t EXCEPTION_POSITION_SIZE = sizeof(uint16_t); @@ -66,7 +68,6 @@ struct AlpTypedConstants {}; template <> struct AlpTypedConstants { - static constexpr float MAGIC_NUMBER = 12582912.0; //! 2^22 + 2^23 static constexpr uint8_t MAX_EXPONENT = 10; @@ -80,7 +81,6 @@ struct AlpTypedConstants { template <> struct AlpTypedConstants { - static constexpr double MAGIC_NUMBER = 6755399441055744.0; //! 2^51 + 2^52 static constexpr uint8_t MAX_EXPONENT = 18; //! 
10^18 is the maximum int64 diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp index 28b52b848..6bc261003 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp @@ -140,6 +140,14 @@ struct AlpScanState : public SegmentScanState { vector_state.v_exponent = Load(vector_ptr); vector_ptr += AlpConstants::EXPONENT_SIZE; + const bool uncompressed_mode = vector_state.v_exponent == AlpConstants::UNCOMPRESSED_MODE_SENTINEL; + if (uncompressed_mode) { + if (!SKIP) { + // Read uncompressed values + memcpy(value_buffer, vector_ptr, sizeof(T) * vector_size); + } + return; + } vector_state.v_factor = Load(vector_ptr); vector_ptr += AlpConstants::FACTOR_SIZE; @@ -153,7 +161,6 @@ struct AlpScanState : public SegmentScanState { vector_ptr += AlpConstants::BIT_WIDTH_SIZE; D_ASSERT(vector_state.exceptions_count <= vector_size); - D_ASSERT(vector_state.v_exponent <= AlpTypedConstants::MAX_EXPONENT); D_ASSERT(vector_state.v_factor <= vector_state.v_exponent); D_ASSERT(vector_state.bit_width <= sizeof(uint64_t) * 8); @@ -201,7 +208,7 @@ struct AlpScanState : public SegmentScanState { }; template -unique_ptr AlpInitScan(ColumnSegment &segment) { +unique_ptr AlpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp index b9c0e6eab..1fa3b3664 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_utils.hpp @@ -36,7 +36,6 @@ class AlpUtils { public: static AlpSamplingParameters GetSamplingParameters(idx_t current_vector_n_values) { - auto n_lookup_values = NumericCast(MinValue(current_vector_n_values, (idx_t)AlpConstants::ALP_VECTOR_SIZE)); //! 
We sample equidistant values within a vector; to do this we jump a fixed number of values diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp index 1dac66f4a..720617d23 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp @@ -32,11 +32,11 @@ struct AlpRDLeftPartInfo { }; template -class AlpRDCompressionState { +class AlpRDCompressionData { public: using EXACT_TYPE = typename FloatingToExact::TYPE; - AlpRDCompressionState() : right_bit_width(0), left_bit_width(0), exceptions_count(0) { + AlpRDCompressionData() : right_bit_width(0), left_bit_width(0), exceptions_count(0) { } void Reset() { @@ -58,11 +58,19 @@ class AlpRDCompressionState { idx_t right_bit_packed_size; unordered_map left_parts_dict_map; uint8_t actual_dictionary_size; + + idx_t RequiredSpace() const { + const idx_t required_space = left_bit_packed_size + right_bit_packed_size + + static_cast(exceptions_count) * + (AlpRDConstants::EXCEPTION_SIZE + AlpRDConstants::EXCEPTION_POSITION_SIZE) + + AlpRDConstants::EXCEPTIONS_COUNT_SIZE; + return required_space; + } }; template struct AlpRDCompression { - using State = AlpRDCompressionState; + using CompressionData = AlpRDCompressionData; using EXACT_TYPE = typename FloatingToExact::TYPE; static constexpr uint8_t EXACT_TYPE_BITSIZE = sizeof(EXACT_TYPE) * 8; @@ -79,7 +87,8 @@ struct AlpRDCompression { } template - static double BuildLeftPartsDictionary(const vector &values, uint8_t right_bit_width, State &state) { + static double BuildLeftPartsDictionary(const vector &values, uint8_t right_bit_width, + CompressionData &compression_data) { unordered_map left_parts_hash; vector left_parts_sorted_repetitions; @@ -112,21 +121,21 @@ struct AlpRDCompression { if (PERSIST_DICT) { for (idx_t dict_idx = 0; dict_idx < actual_dictionary_size; dict_idx++) { //! The dict keys are mapped to the left part themselves - state.left_parts_dict[dict_idx] = + compression_data.left_parts_dict[dict_idx] = UnsafeNumericCast(left_parts_sorted_repetitions[dict_idx].hash); - state.left_parts_dict_map.insert({state.left_parts_dict[dict_idx], dict_idx}); + compression_data.left_parts_dict_map.insert({compression_data.left_parts_dict[dict_idx], dict_idx}); } //! 
Pararelly we store a map of the dictionary to quickly resolve exceptions during encoding for (idx_t i = actual_dictionary_size + 1; i < left_parts_sorted_repetitions.size(); i++) { - state.left_parts_dict_map.insert({left_parts_sorted_repetitions[i].hash, i}); + compression_data.left_parts_dict_map.insert({left_parts_sorted_repetitions[i].hash, i}); } - state.left_bit_width = left_bit_width; - state.right_bit_width = right_bit_width; - state.actual_dictionary_size = UnsafeNumericCast(actual_dictionary_size); + compression_data.left_bit_width = left_bit_width; + compression_data.right_bit_width = right_bit_width; + compression_data.actual_dictionary_size = UnsafeNumericCast(actual_dictionary_size); - D_ASSERT(state.left_bit_width > 0 && state.right_bit_width > 0 && - state.left_bit_width <= AlpRDConstants::MAX_DICTIONARY_BIT_WIDTH && - state.actual_dictionary_size <= AlpRDConstants::MAX_DICTIONARY_SIZE); + D_ASSERT(compression_data.left_bit_width > 0 && compression_data.right_bit_width > 0 && + compression_data.left_bit_width <= AlpRDConstants::MAX_DICTIONARY_BIT_WIDTH && + compression_data.actual_dictionary_size <= AlpRDConstants::MAX_DICTIONARY_SIZE); } double estimated_size = EstimateCompressionSize(right_bit_width, left_bit_width, @@ -134,68 +143,70 @@ struct AlpRDCompression { return estimated_size; } - static double FindBestDictionary(const vector &values, State &state) { + static double FindBestDictionary(const vector &values, CompressionData &compression_data) { uint8_t right_bit_width = 0; double best_dict_size = NumericLimits::Maximum(); //! Finding the best position to CUT the values for (idx_t i = 1; i <= AlpRDConstants::CUTTING_LIMIT; i++) { uint8_t candidate_right_bit_width = UnsafeNumericCast(EXACT_TYPE_BITSIZE - i); - double estimated_size = BuildLeftPartsDictionary(values, candidate_right_bit_width, state); + double estimated_size = + BuildLeftPartsDictionary(values, candidate_right_bit_width, compression_data); if (estimated_size <= best_dict_size) { right_bit_width = candidate_right_bit_width; best_dict_size = estimated_size; } // TODO: We could implement an early exit mechanism similar to normal ALP } - double estimated_size = BuildLeftPartsDictionary(values, right_bit_width, state); + double estimated_size = BuildLeftPartsDictionary(values, right_bit_width, compression_data); return estimated_size; } - static void Compress(const EXACT_TYPE *input_vector, idx_t n_values, State &state) { - + static void Compress(const EXACT_TYPE *input_vector, idx_t n_values, CompressionData &compression_data) { uint64_t right_parts[AlpRDConstants::ALP_VECTOR_SIZE]; uint16_t left_parts[AlpRDConstants::ALP_VECTOR_SIZE]; // Cutting the floating point values for (idx_t i = 0; i < n_values; i++) { EXACT_TYPE tmp = input_vector[i]; - right_parts[i] = tmp & ((1ULL << state.right_bit_width) - 1); - left_parts[i] = UnsafeNumericCast(tmp >> state.right_bit_width); + right_parts[i] = tmp & ((1ULL << compression_data.right_bit_width) - 1); + left_parts[i] = UnsafeNumericCast(tmp >> compression_data.right_bit_width); } // Dictionary encoding for left parts for (idx_t i = 0; i < n_values; i++) { uint16_t dictionary_index; auto dictionary_key = left_parts[i]; - if (state.left_parts_dict_map.find(dictionary_key) == state.left_parts_dict_map.end()) { + if (compression_data.left_parts_dict_map.find(dictionary_key) == + compression_data.left_parts_dict_map.end()) { //! 
If not found on the dictionary we store the smallest non-key index as exception (the dict size) - dictionary_index = state.actual_dictionary_size; + dictionary_index = compression_data.actual_dictionary_size; } else { - dictionary_index = state.left_parts_dict_map[dictionary_key]; + dictionary_index = compression_data.left_parts_dict_map[dictionary_key]; } left_parts[i] = dictionary_index; //! Left parts not found in the dictionary are stored as exceptions - if (dictionary_index >= state.actual_dictionary_size) { - state.exceptions[state.exceptions_count] = dictionary_key; - state.exceptions_positions[state.exceptions_count] = UnsafeNumericCast(i); - state.exceptions_count++; + if (dictionary_index >= compression_data.actual_dictionary_size) { + compression_data.exceptions[compression_data.exceptions_count] = dictionary_key; + compression_data.exceptions_positions[compression_data.exceptions_count] = + UnsafeNumericCast(i); + compression_data.exceptions_count++; } } - auto right_bit_packed_size = BitpackingPrimitives::GetRequiredSize(n_values, state.right_bit_width); - auto left_bit_packed_size = BitpackingPrimitives::GetRequiredSize(n_values, state.left_bit_width); + auto right_bit_packed_size = BitpackingPrimitives::GetRequiredSize(n_values, compression_data.right_bit_width); + auto left_bit_packed_size = BitpackingPrimitives::GetRequiredSize(n_values, compression_data.left_bit_width); if (!EMPTY) { // Bitpacking Left and Right parts - BitpackingPrimitives::PackBuffer(state.left_parts_encoded, left_parts, n_values, - state.left_bit_width); - BitpackingPrimitives::PackBuffer(state.right_parts_encoded, right_parts, n_values, - state.right_bit_width); + BitpackingPrimitives::PackBuffer(compression_data.left_parts_encoded, left_parts, n_values, + compression_data.left_bit_width); + BitpackingPrimitives::PackBuffer(compression_data.right_parts_encoded, right_parts, + n_values, compression_data.right_bit_width); } - state.left_bit_packed_size = left_bit_packed_size; - state.right_bit_packed_size = right_bit_packed_size; + compression_data.left_bit_packed_size = left_bit_packed_size; + compression_data.right_bit_packed_size = right_bit_packed_size; } }; @@ -207,7 +218,6 @@ struct AlpRDDecompression { EXACT_TYPE *output, idx_t values_count, uint16_t exceptions_count, const uint16_t *exceptions, const uint16_t *exceptions_positions, uint8_t left_bit_width, uint8_t right_bit_width) { - uint8_t left_decoded[AlpRDConstants::ALP_VECTOR_SIZE * 8] = {0}; uint8_t right_decoded[AlpRDConstants::ALP_VECTOR_SIZE * 8] = {0}; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp index 25901667e..2fb1fc874 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp @@ -26,14 +26,14 @@ struct AlpRDAnalyzeState : public AnalyzeState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; - explicit AlpRDAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), state() { + explicit AlpRDAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), compression_data() { } idx_t vectors_count = 0; idx_t total_values_count = 0; idx_t vectors_sampled_count = 0; vector rowgroup_sample; - alp::AlpRDCompressionState state; + alp::AlpRDCompressionData compression_data; }; template @@ -47,8 +47,12 @@ unique_ptr AlpRDInitAnalyze(ColumnData &col_data, PhysicalType typ */ template bool 
AlpRDAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + using EXACT_TYPE = typename FloatingToExact::TYPE; - auto &analyze_state = (AlpRDAnalyzeState &)state; + auto &analyze_state = state.Cast>(); bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); @@ -118,11 +122,11 @@ idx_t AlpRDFinalAnalyze(AnalyzeState &state) { static_cast(analyze_state.total_values_count)); // Finding which is the best dictionary for the sample - double estimated_bits_per_value = - alp::AlpRDCompression::FindBestDictionary(analyze_state.rowgroup_sample, analyze_state.state); + double estimated_bits_per_value = alp::AlpRDCompression::FindBestDictionary( + analyze_state.rowgroup_sample, analyze_state.compression_data); double estimated_compressed_bits = estimated_bits_per_value * static_cast(analyze_state.rowgroup_sample.size()); - double estimed_compressed_bytes = estimated_compressed_bits / 8; + double estimated_compressed_bytes = estimated_compressed_bits / 8; //! Overhead per segment: [Pointer to metadata + right bitwidth + left bitwidth + n dict elems] + Dictionary Size double per_segment_overhead = AlpRDConstants::HEADER_SIZE + AlpRDConstants::MAX_DICTIONARY_SIZE_BYTES; @@ -133,7 +137,7 @@ idx_t AlpRDFinalAnalyze(AnalyzeState &state) { uint32_t n_vectors = LossyNumericCast( std::ceil((double)analyze_state.total_values_count / AlpRDConstants::ALP_VECTOR_SIZE)); - auto estimated_size = (estimed_compressed_bytes * factor_of_sampling) + (n_vectors * per_vector_overhead); + auto estimated_size = (estimated_compressed_bytes * factor_of_sampling) + (n_vectors * per_vector_overhead); uint32_t estimated_n_blocks = LossyNumericCast( std::ceil(estimated_size / (static_cast(state.info.GetBlockSize()) - per_segment_overhead))); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp index 86559d604..ef12c6985 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp @@ -30,7 +30,6 @@ namespace duckdb { template struct AlpRDCompressionState : public CompressionState { - public: using EXACT_TYPE = typename FloatingToExact::TYPE; @@ -38,15 +37,16 @@ struct AlpRDCompressionState : public CompressionState { : CompressionState(analyze_state->info), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ALPRD)) { //! 
State variables from the analyze step that are needed for compression - state.left_parts_dict_map = std::move(analyze_state->state.left_parts_dict_map); - state.left_bit_width = analyze_state->state.left_bit_width; - state.right_bit_width = analyze_state->state.right_bit_width; - state.actual_dictionary_size = analyze_state->state.actual_dictionary_size; - actual_dictionary_size_bytes = state.actual_dictionary_size * AlpRDConstants::DICTIONARY_ELEMENT_SIZE; + compression_data.left_parts_dict_map = std::move(analyze_state->compression_data.left_parts_dict_map); + compression_data.left_bit_width = analyze_state->compression_data.left_bit_width; + compression_data.right_bit_width = analyze_state->compression_data.right_bit_width; + compression_data.actual_dictionary_size = analyze_state->compression_data.actual_dictionary_size; + actual_dictionary_size_bytes = + compression_data.actual_dictionary_size * AlpRDConstants::DICTIONARY_ELEMENT_SIZE; next_vector_byte_index_start = AlpRDConstants::HEADER_SIZE + actual_dictionary_size_bytes; - memcpy((void *)state.left_parts_dict, (void *)analyze_state->state.left_parts_dict, + memcpy((void *)compression_data.left_parts_dict, (void *)analyze_state->compression_data.left_parts_dict, actual_dictionary_size_bytes); - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); } ColumnDataCheckpointData &checkpoint_data; @@ -67,7 +67,7 @@ struct AlpRDCompressionState : public CompressionState { EXACT_TYPE input_vector[AlpRDConstants::ALP_VECTOR_SIZE]; uint16_t vector_null_positions[AlpRDConstants::ALP_VECTOR_SIZE]; - alp::AlpRDCompressionState state; + alp::AlpRDCompressionData compression_data; public: // Returns the space currently used in the segment (in bytes) @@ -76,19 +76,10 @@ struct AlpRDCompressionState : public CompressionState { return AlpRDConstants::HEADER_SIZE + actual_dictionary_size_bytes + data_bytes_used; } - // Returns the required space to store the newly compressed vector - idx_t RequiredSpace() { - idx_t required_space = - state.left_bit_packed_size + state.right_bit_packed_size + - state.exceptions_count * (AlpRDConstants::EXCEPTION_SIZE + AlpRDConstants::EXCEPTION_POSITION_SIZE) + - AlpRDConstants::EXCEPTIONS_COUNT_SIZE; - return required_space; - } - - bool HasEnoughSpace() { + bool HasEnoughSpace(idx_t vector_size) { //! If [start of block + used space + required space] is more than whats left (current position //! 
of metadata pointer - the size of a new metadata pointer) - if ((handle.Ptr() + AlignValue(UsedSpace() + RequiredSpace())) >= + if ((handle.Ptr() + AlignValue(UsedSpace() + vector_size)) >= (metadata_ptr - AlpRDConstants::METADATA_POINTER_SIZE)) { return false; } @@ -96,15 +87,15 @@ struct AlpRDCompressionState : public CompressionState { } void ResetVector() { - state.Reset(); + compression_data.Reset(); } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); auto &buffer_manager = BufferManager::GetBufferManager(db); @@ -123,46 +114,88 @@ struct AlpRDCompressionState : public CompressionState { alp::AlpUtils::FindAndReplaceNullsInVector(input_vector, vector_null_positions, vector_idx, nulls_idx); } - alp::AlpRDCompression::Compress(input_vector, vector_idx, state); + alp::AlpRDCompression::Compress(input_vector, vector_idx, compression_data); + + const idx_t uncompressed_size = AlpConstants::EXCEPTIONS_COUNT_SIZE + sizeof(EXACT_TYPE) * vector_idx; + const idx_t compressed_size = compression_data.RequiredSpace(); + + const auto storage_version = checkpoint_data.GetStorageManager().GetStorageVersion(); + const bool should_compress = compressed_size < uncompressed_size || storage_version < 7; + + const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; + //! Check if the compressed vector fits on current segment - if (!HasEnoughSpace()) { - auto row_start = current_segment->start + current_segment->count; + if (!HasEnoughSpace(vector_size)) { FlushSegment(); - CreateEmptySegment(row_start); + CreateEmptySegment(); + } + if (nulls_idx) { + current_segment->stats.statistics.SetHasNullFast(); } if (vector_idx != nulls_idx) { //! 
At least there is one valid value in the vector + current_segment->stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < vector_idx; i++) { T floating_point_value = Load(const_data_ptr_cast(&input_vector[i])); current_segment->stats.statistics.UpdateNumericStats(floating_point_value); } } current_segment->count += vector_idx; - FlushVector(); + + if (should_compress) { + FlushCompressedVector(); + } else { + FlushUncompressedVector(); + } } // Stores the vector and its metadata - void FlushVector() { - Store(state.exceptions_count, data_ptr); + void FlushCompressedVector() { + Store(compression_data.exceptions_count, data_ptr); data_ptr += AlpRDConstants::EXCEPTIONS_COUNT_SIZE; - memcpy((void *)data_ptr, (void *)state.left_parts_encoded, state.left_bit_packed_size); - data_ptr += state.left_bit_packed_size; + memcpy((void *)data_ptr, (void *)compression_data.left_parts_encoded, compression_data.left_bit_packed_size); + data_ptr += compression_data.left_bit_packed_size; - memcpy((void *)data_ptr, (void *)state.right_parts_encoded, state.right_bit_packed_size); - data_ptr += state.right_bit_packed_size; + memcpy((void *)data_ptr, (void *)compression_data.right_parts_encoded, compression_data.right_bit_packed_size); + data_ptr += compression_data.right_bit_packed_size; - if (state.exceptions_count > 0) { - memcpy((void *)data_ptr, (void *)state.exceptions, AlpRDConstants::EXCEPTION_SIZE * state.exceptions_count); - data_ptr += AlpRDConstants::EXCEPTION_SIZE * state.exceptions_count; - memcpy((void *)data_ptr, (void *)state.exceptions_positions, - AlpRDConstants::EXCEPTION_POSITION_SIZE * state.exceptions_count); - data_ptr += AlpRDConstants::EXCEPTION_POSITION_SIZE * state.exceptions_count; + if (compression_data.exceptions_count > 0) { + memcpy((void *)data_ptr, (void *)compression_data.exceptions, + AlpRDConstants::EXCEPTION_SIZE * compression_data.exceptions_count); + data_ptr += AlpRDConstants::EXCEPTION_SIZE * compression_data.exceptions_count; + memcpy((void *)data_ptr, (void *)compression_data.exceptions_positions, + AlpRDConstants::EXCEPTION_POSITION_SIZE * compression_data.exceptions_count); + data_ptr += AlpRDConstants::EXCEPTION_POSITION_SIZE * compression_data.exceptions_count; } - data_bytes_used += - state.left_bit_packed_size + state.right_bit_packed_size + - (state.exceptions_count * (AlpRDConstants::EXCEPTION_SIZE + AlpRDConstants::EXCEPTION_POSITION_SIZE)) + - AlpRDConstants::EXCEPTIONS_COUNT_SIZE; + data_bytes_used += compression_data.left_bit_packed_size + compression_data.right_bit_packed_size + + (compression_data.exceptions_count * + (AlpRDConstants::EXCEPTION_SIZE + AlpRDConstants::EXCEPTION_POSITION_SIZE)) + + AlpRDConstants::EXCEPTIONS_COUNT_SIZE; + + // Write pointer to the vector data (metadata) + metadata_ptr -= AlpRDConstants::METADATA_POINTER_SIZE; + Store(next_vector_byte_index_start, metadata_ptr); + next_vector_byte_index_start = NumericCast(UsedSpace()); + + vectors_flushed++; + vector_idx = 0; + nulls_idx = 0; + ResetVector(); + } + + //! Uncompressed mode + void FlushUncompressedVector() { + // Store a sentinel value, signaling the coming data is stored uncompressed. 
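The compress-or-store-raw decision used by both ALP and ALP-RD above reduces to a size comparison plus a storage-version guard; the following restatement is illustrative only (the byte counts in the comment are made up, the predicate mirrors the CompressVector code in this diff).

// For one ALP-RD vector of 1024 doubles: uncompressed_size = EXCEPTIONS_COUNT_SIZE + 8 * 1024 = 8194 bytes.
// If RequiredSpace() returns e.g. 5000 bytes the vector is compressed; if it returns e.g. 9500 bytes,
// the raw values are flushed behind the UNCOMPRESSED_MODE_SENTINEL instead.
bool ShouldCompressVector(idx_t compressed_size, idx_t uncompressed_size, idx_t storage_version) {
	// Storage versions older than 7 cannot read the sentinel, so they always compress.
	return compressed_size < uncompressed_size || storage_version < 7;
}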
+ constexpr uint16_t sentinel = AlpRDConstants::UNCOMPRESSED_MODE_SENTINEL; + Store(sentinel, data_ptr); + data_ptr += AlpRDConstants::EXCEPTIONS_COUNT_SIZE; + + // Store uncompressed data + memcpy(data_ptr, input_vector, sizeof(EXACT_TYPE) * vector_idx); + data_ptr += sizeof(EXACT_TYPE) * vector_idx; + + data_bytes_used += AlpConstants::EXCEPTIONS_COUNT_SIZE + (sizeof(EXACT_TYPE) * vector_idx); // Write pointer to the vector data (metadata) metadata_ptr -= AlpRDConstants::METADATA_POINTER_SIZE; @@ -211,19 +244,19 @@ struct AlpRDCompressionState : public CompressionState { dataptr += AlpRDConstants::METADATA_POINTER_SIZE; // Store the right bw for the segment - Store(state.right_bit_width, dataptr); + Store(compression_data.right_bit_width, dataptr); dataptr += AlpRDConstants::RIGHT_BIT_WIDTH_SIZE; // Store the left bw for the segment - Store(state.left_bit_width, dataptr); + Store(compression_data.left_bit_width, dataptr); dataptr += AlpRDConstants::LEFT_BIT_WIDTH_SIZE; // Store the actual number of elements on the dictionary of the segment - Store(state.actual_dictionary_size, dataptr); + Store(compression_data.actual_dictionary_size, dataptr); dataptr += AlpRDConstants::N_DICTIONARY_ELEMENTS_SIZE; // Store the Dictionary - memcpy((void *)dataptr, (void *)state.left_parts_dict, actual_dictionary_size_bytes); + memcpy((void *)dataptr, (void *)compression_data.left_parts_dict, actual_dictionary_size_bytes); checkpoint_state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); data_bytes_used = 0; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_constants.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_constants.hpp index 521551e2d..c0dc6f092 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_constants.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_constants.hpp @@ -8,6 +8,8 @@ #pragma once +#include "duckdb/common/limits.hpp" + namespace duckdb { class AlpRDConstants { @@ -22,7 +24,9 @@ class AlpRDConstants { static constexpr uint8_t EXCEPTION_SIZE = sizeof(uint16_t); static constexpr uint8_t METADATA_POINTER_SIZE = sizeof(uint32_t); + //! 
exceptions_count can store the UNCOMPRESSED_MODE_SENTINEL value static constexpr uint8_t EXCEPTIONS_COUNT_SIZE = sizeof(uint16_t); + static constexpr uint16_t UNCOMPRESSED_MODE_SENTINEL = std::numeric_limits::max(); static constexpr uint8_t EXCEPTION_POSITION_SIZE = sizeof(uint16_t); static constexpr uint8_t RIGHT_BIT_WIDTH_SIZE = sizeof(uint8_t); static constexpr uint8_t LEFT_BIT_WIDTH_SIZE = sizeof(uint8_t); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp index a3feb94b5..3e2d80efd 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp @@ -158,7 +158,15 @@ struct AlpRDScanState : public SegmentScanState { // Load the vector data vector_state.exceptions_count = Load(vector_ptr); vector_ptr += AlpRDConstants::EXCEPTIONS_COUNT_SIZE; - D_ASSERT(vector_state.exceptions_count <= vector_size); + + const bool uncompressed_mode = vector_state.exceptions_count == AlpRDConstants::UNCOMPRESSED_MODE_SENTINEL; + if (uncompressed_mode) { + if (!SKIP) { + // Read uncompressed values + memcpy(value_buffer, vector_ptr, sizeof(T) * vector_size); + } + return; + } auto left_bp_size = BitpackingPrimitives::GetRequiredSize(vector_size, vector_state.left_bit_width); auto right_bp_size = BitpackingPrimitives::GetRequiredSize(vector_size, vector_state.right_bit_width); @@ -208,7 +216,7 @@ struct AlpRDScanState : public SegmentScanState { }; template -unique_ptr AlpRDInitScan(ColumnSegment &segment) { +unique_ptr AlpRDInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp index 277e30b6a..a2b0566e1 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/chimp128.hpp @@ -31,7 +31,6 @@ namespace duckdb { template struct Chimp128CompressionState { - Chimp128CompressionState() : ring_buffer(), previous_leading_zeros(NumericLimits::Maximum()) { previous_value = 0; } @@ -104,7 +103,6 @@ class Chimp128Compression { } static void CompressValue(CHIMP_TYPE in, State &state) { - auto key = state.ring_buffer.Key(in); CHIMP_TYPE xor_result; uint8_t previous_index; diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp index f5c0d70ba..7a38c065e 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp @@ -34,7 +34,6 @@ struct FlagBufferConstants { // So we can just read/write from left to right template class FlagBuffer { - public: FlagBuffer() : counter(0), buffer(nullptr) { } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp index c4b23cfd9..23376dc1f 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/leading_zero_buffer.hpp @@ -40,7 +40,6 @@ struct 
LeadingZeroBufferConstants { template class LeadingZeroBuffer { - public: static constexpr uint32_t CHIMP_GROUP_SIZE = 1024; static constexpr uint32_t LEADING_ZERO_BITS_SIZE = 3; diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp index de11979cb..00b27f9e8 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp @@ -185,7 +185,6 @@ struct ChimpScanState : public SegmentScanState { } void LoadGroup(CHIMP_TYPE *value_buffer) { - //! FIXME: If we change the order of this to flag -> leading_zero_blocks -> packed_data //! We can leave out the leading zero block count as well, because it can be derived from //! Extracting all the flags and counting the 3's @@ -252,7 +251,7 @@ struct ChimpScanState : public SegmentScanState { }; template -unique_ptr ChimpInitScan(ColumnSegment &segment) { +unique_ptr ChimpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp index 1bc613a0a..94e26c5b4 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp @@ -1,5 +1,6 @@ #pragma once +#include "duckdb/common/primitive_dictionary.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/storage/compression/dict_fsst/common.hpp" #include "duckdb/storage/compression/dict_fsst/analyze.hpp" @@ -38,7 +39,7 @@ struct DictFSSTCompressionState : public CompressionState { ~DictFSSTCompressionState() override; public: - void CreateEmptySegment(idx_t row_start); + void CreateEmptySegment(); idx_t Finalize(); bool AllUnique() const; @@ -75,7 +76,7 @@ struct DictFSSTCompressionState : public CompressionState { bitpacking_width_t dictionary_indices_width = 0; //! string -> dictionary_index (for lookups) - string_map_t current_string_map; + PrimitiveDictionary current_string_map; //! 
strings added to the dictionary waiting to be encoded vector dictionary_encoding_buffer; idx_t to_encode_string_sum = 0; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp index 032370f86..1cb377baf 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/decompression.hpp @@ -59,7 +59,7 @@ struct CompressedStringScanState : public SegmentScanState { data_ptr_t dictionary_indices_ptr; data_ptr_t string_lengths_ptr; - buffer_ptr dictionary; + buffer_ptr dictionary; void *decoder = nullptr; bool all_values_inlined = false; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp index 99eb72156..14565e476 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp @@ -21,11 +21,14 @@ struct DictionaryAnalyzeState : public DictionaryCompressionState { bool CalculateSpaceRequirements(bool new_string, idx_t string_size) override; void Flush(bool final = false) override; void Verify() override; + void UpdateMaxUniqueCount(); public: idx_t segment_count; idx_t current_tuple_count; idx_t current_unique_count; + idx_t max_unique_count_across_segments = + 0; // Is used to allocate the dictionary optimally later on at the InitCompression step idx_t current_dict_size; StringHeap heap; string_set_t current_set; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp index 09f1f44bd..be255fea0 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp @@ -1,5 +1,6 @@ #pragma once +#include "duckdb/common/primitive_dictionary.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/storage/compression/dictionary/common.hpp" #include "duckdb/function/compression_function.hpp" @@ -23,10 +24,11 @@ namespace duckdb { //===--------------------------------------------------------------------===// struct DictionaryCompressionCompressState : public DictionaryCompressionState { public: - DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info); + DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info, + idx_t max_unique_count_across_all_segments); public: - void CreateEmptySegment(idx_t row_start); + void CreateEmptySegment(); void Verify() override; bool LookupString(string_t str) override; void AddNewString(string_t str) override; @@ -47,7 +49,7 @@ struct DictionaryCompressionCompressState : public DictionaryCompressionState { data_ptr_t current_end_ptr; // Buffers and map for current segment - string_map_t current_string_map; + PrimitiveDictionary current_string_map; vector index_buffer; vector selection_buffer; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp index 1656ec718..e7381f11a 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp +++ 
b/src/duckdb/src/include/duckdb/storage/compression/dictionary/decompression.hpp @@ -41,7 +41,7 @@ struct CompressedStringScanState : public StringScanState { uint32_t *index_buffer_ptr; uint32_t index_buffer_count; - buffer_ptr dictionary; + buffer_ptr dictionary; idx_t dictionary_size; StringDictionaryContainer dict; idx_t block_size; diff --git a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp index 1118f77f2..26af6fa4b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp @@ -58,11 +58,10 @@ class EmptyValidityCompression { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto row_start = checkpoint_data.GetRowGroup().start; auto &info = state.info; - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, *state.function, type, row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = ColumnSegment::CreateTransientSegment(db, *state.function, type, info.GetBlockSize(), + info.GetBlockManager()); compressed_segment->count = state.count; if (state.non_nulls != state.count) { compressed_segment->stats.statistics.SetHasNullFast(); @@ -77,7 +76,7 @@ class EmptyValidityCompression { auto &checkpoint_state = checkpoint_data.GetCheckpointState(); checkpoint_state.FlushSegment(std::move(compressed_segment), std::move(handle), 0); } - static unique_ptr InitScan(ColumnSegment &segment) { + static unique_ptr InitScan(const QueryContext &context, ColumnSegment &segment) { return make_uniq(); } static void ScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index b523600e3..4261d2d23 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -204,7 +204,7 @@ struct PatasScanState : public SegmentScanState { }; template -unique_ptr PatasInitScan(ColumnSegment &segment) { +unique_ptr PatasInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp b/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp index da2fb5710..64d3383b1 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp @@ -223,7 +223,11 @@ struct RoaringAnalyzeState : public AnalyzeState { bool HasEnoughSpaceInSegment(idx_t required_space); void FlushSegment(); void FlushContainer(); - void Analyze(Vector &input, idx_t count); + template + void Analyze(Vector &input, idx_t count) { + static_assert(AlwaysFalse>::VALUE, + "No specialization exists for this type"); + } public: unsafe_unique_array bitmask_table; @@ -260,6 +264,10 @@ struct RoaringAnalyzeState : public AnalyzeState { ContainerMetadataCollection metadata_collection; vector container_metadata; }; +template <> +void RoaringAnalyzeState::Analyze(Vector &input, idx_t count); +template <> +void RoaringAnalyzeState::Analyze(Vector &input, idx_t count); //===--------------------------------------------------------------------===// // Compress @@ -334,15 
+342,21 @@ struct RoaringCompressState : public CompressionState { public: idx_t GetContainerIndex(); - idx_t GetRemainingSpace(); + idx_t GetUsedDataSpace(); + idx_t GetAvailableSpace(); bool CanStore(idx_t container_size, const ContainerMetadata &metadata); void InitializeContainer(); - void CreateEmptySegment(idx_t row_start); + void CreateEmptySegment(); void FlushSegment(); void Finalize(); void FlushContainer(); void NextContainer(); void Compress(Vector &input, idx_t count); + template + void Compress(Vector &input, idx_t count) { + static_assert(AlwaysFalse>::VALUE, + "No specialization exists for this type"); + } public: unique_ptr owned_analyze_state; @@ -365,6 +379,11 @@ struct RoaringCompressState : public CompressionState { idx_t total_count = 0; }; +template <> +void RoaringCompressState::Compress(Vector &input, idx_t count); +template <> +void RoaringCompressState::Compress(Vector &input, idx_t count); + //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// @@ -613,6 +632,69 @@ struct RoaringScanState : public SegmentScanState { vector data_start_position; }; +//! Boolean BitPacking + +template +static void BitPackBooleans(data_ptr_t dst, const bool *src, const idx_t count, + const ValidityMask *validity_mask = nullptr, BaseStatistics *statistics = nullptr) { + uint8_t byte = 0; + int bit_pos = 0; + uint8_t src_bit = false; + + if (ALL_VALID) { + if (UPDATE_STATS) { + statistics->SetHasNoNullFast(); + } + for (idx_t i = 0; i < count; i++) { + src_bit = src[i]; + + if (UPDATE_STATS) { + statistics->UpdateNumericStats(src_bit); + } + byte |= src_bit << bit_pos; + bit_pos++; + + // flush + if (bit_pos == 8) { + *dst++ = byte; + byte = 0; + bit_pos = 0; + } + } + } else { + bool last_bit_value = false; + for (idx_t i = 0; i < count; i++) { + const uint8_t valid = validity_mask->RowIsValid(i); + src_bit = valid ? src[i] : src_bit; + const uint8_t bit = (src_bit & valid) | (last_bit_value & ~valid); + + byte |= bit << bit_pos; + bit_pos++; + + last_bit_value = src_bit; + + if (UPDATE_STATS) { + if (valid) { + statistics->UpdateNumericStats(src_bit); + statistics->SetHasNoNullFast(); + } else { + statistics->SetHasNullFast(); + } + } + + // flush + if (bit_pos == 8) { + *dst++ = byte; + byte = 0; + bit_pos = 0; + } + } + } + // flush last partial byte + if (bit_pos != 0) { + *dst = byte; + } +} } // namespace roaring } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp index 0ae44d7f3..54f6f239f 100644 --- a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -19,6 +19,7 @@ namespace duckdb { class Serializer; class Deserializer; +class QueryContext; struct ColumnSegmentState { virtual ~ColumnSegmentState() { diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp index bc8727a18..d6b07ed23 100644 --- a/src/duckdb/src/include/duckdb/storage/data_table.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -159,7 +159,7 @@ class DataTable : public enable_shared_from_this { const vector &column_path, DataChunk &updates); //! Fetches an append lock - void AppendLock(TableAppendState &state); + void AppendLock(DuckTransaction &transaction, TableAppendState &state); //! 
Begin appending structs to this table, obtaining necessary locks, etc void InitializeAppend(DuckTransaction &transaction, TableAppendState &state); //! Append a chunk to the table using the AppendState obtained from InitializeAppend @@ -191,12 +191,13 @@ class DataTable : public enable_shared_from_this { ErrorData AppendToIndexes(optional_ptr delete_indexes, DataChunk &table_chunk, DataChunk &index_chunk, const vector &mapped_column_ids, row_t row_start, const IndexAppendMode index_append_mode); - //! Remove a chunk with the row ids [row_start, ..., row_start + chunk.size()] from all indexes of the table - void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row_t row_start); - //! Remove the chunk with the specified set of row identifiers from all indexes of the table - void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers); + //! Revert a previous append made to indexes in a chunk with the row ids [row_start, ..., row_start + chunk.size()] + void RevertIndexAppend(TableAppendState &state, DataChunk &chunk, row_t row_start); + //! Revert a previous append made to indexes with the given row-ids + void RevertIndexAppend(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers); //! Remove the row identifiers from all the indexes of the table - void RemoveFromIndexes(Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count, + IndexRemovalType removal_type); void SetAsMainTable() { this->version = DataTableVersion::MAIN_TABLE; @@ -222,8 +223,6 @@ class DataTable : public enable_shared_from_this { //! Sets statistics of a physical column within the table void SetDistinct(column_t column_id, unique_ptr distinct_stats); - //! Obtains a shared lock to prevent checkpointing while operations are running - unique_ptr GetSharedCheckpointLock(); //! Obtains a lock during a checkpoint operation that prevents other threads from reading this table unique_ptr GetCheckpointLock(); //! Checkpoint the table to the specified table data writer @@ -234,7 +233,7 @@ class DataTable : public enable_shared_from_this { idx_t ColumnCount() const; idx_t GetTotalRows() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); //! 
Scans the next chunk for the CREATE INDEX operator bool CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type); diff --git a/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp index 16ed70527..aba2d67a2 100644 --- a/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/in_memory_block_manager.hpp @@ -30,14 +30,17 @@ class InMemoryBlockManager : public BlockManager { block_id_t GetFreeBlockId() override { throw InternalException("Cannot perform IO in in-memory database - GetFreeBlockId!"); } + block_id_t GetFreeBlockIdForCheckpoint() override { + throw InternalException("Cannot perform IO in in-memory database - GetFreeBlockIdForCheckpoint!"); + } block_id_t PeekFreeBlockId() override { throw InternalException("Cannot perform IO in in-memory database - PeekFreeBlockId!"); } bool IsRootBlock(MetaBlockPointer root) override { throw InternalException("Cannot perform IO in in-memory database - IsRootBlock!"); } - void MarkBlockAsFree(block_id_t block_id) override { - throw InternalException("Cannot perform IO in in-memory database - MarkBlockAsFree!"); + void MarkBlockACheckpointed(block_id_t block_id) override { + throw InternalException("Cannot perform IO in in-memory database - MarkBlockACheckpointed!"); } void MarkBlockAsUsed(block_id_t block_id) override { throw InternalException("Cannot perform IO in in-memory database - MarkBlockAsUsed!"); diff --git a/src/duckdb/src/include/duckdb/storage/index.hpp b/src/duckdb/src/include/duckdb/storage/index.hpp index 2b624c2c1..492f37e29 100644 --- a/src/duckdb/src/include/duckdb/storage/index.hpp +++ b/src/duckdb/src/include/duckdb/storage/index.hpp @@ -31,9 +31,15 @@ class Index { protected: Index(const vector &column_ids, TableIOManager &table_io_manager, AttachedDatabase &db); - //! The logical column ids of the indexed table + //! The physical column ids of the indexed columns. + //! For example, given a table with the following columns: + //! (a INT, gen AS (2 * a), b INT, c VARCHAR), an index on columns (a,c) would have physical + //! column_ids [0,2] (since the virtual column is skipped in the physical representation). + //! Also see comments in bound_index.hpp to see how these column IDs are used in the context of + //! bound/unbound expressions. + //! Note that these are the columns for this Index, not all Indexes on the table. vector column_ids; - //! Unordered set of column_ids used by the index + //! 
Unordered set of column_ids used by the Index unordered_set column_id_set; public: diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp index cd63a96b8..c2320c29d 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -62,10 +62,14 @@ class MetadataManager { MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager); ~MetadataManager(); + BufferManager &GetBufferManager() const { + return buffer_manager; + } + MetadataHandle AllocateHandle(); MetadataHandle Pin(const MetadataPointer &pointer); - MetadataHandle Pin(QueryContext context, const MetadataPointer &pointer); + MetadataHandle Pin(const QueryContext &context, const MetadataPointer &pointer); MetaBlockPointer GetDiskPointer(const MetadataPointer &pointer, uint32_t offset = 0); MetadataPointer FromDiskPointer(MetaBlockPointer pointer); @@ -77,6 +81,8 @@ class MetadataManager { //! Flush all blocks to disk void Flush(); + bool BlockHasBeenCleared(const MetaBlockPointer &ptr); + void MarkBlocksAsModified(); void ClearModifiedBlocks(const vector &pointers); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp index 51894886a..ce8d01b41 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp @@ -52,7 +52,7 @@ class MetadataReader : public ReadStream { MetadataManager &manager; BlockReaderType type; MetadataHandle block; - MetadataPointer next_pointer; + MetaBlockPointer next_pointer; bool has_next_block; optional_ptr> read_pointers; idx_t index; diff --git a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp index 1ded8bba6..dca2c9b1b 100644 --- a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp @@ -14,11 +14,16 @@ namespace duckdb { class PartialBlockManager; struct OptimisticWriteCollection { + ~OptimisticWriteCollection(); + shared_ptr collection; idx_t last_flushed = 0; idx_t complete_row_groups = 0; + vector> partial_block_managers; }; +enum class OptimisticWritePartialManagers { PER_COLUMN, GLOBAL }; + class OptimisticDataWriter { public: OptimisticDataWriter(ClientContext &context, DataTable &table); @@ -26,18 +31,18 @@ class OptimisticDataWriter { ~OptimisticDataWriter(); //! Creates a collection to write to - static unique_ptr CreateCollection(DataTable &storage, - const vector &insert_types); + unique_ptr + CreateCollection(DataTable &storage, const vector &insert_types, + OptimisticWritePartialManagers type = OptimisticWritePartialManagers::PER_COLUMN); //! Write a new row group to disk (if possible) void WriteNewRowGroup(OptimisticWriteCollection &row_groups); //! Write the last row group of a collection to disk void WriteLastRowGroup(OptimisticWriteCollection &row_groups); //! Final flush of the optimistic writer - fully flushes the partial block manager void FinalFlush(); - //! Flushes a specific row group to disk - void FlushToDisk(const vector> &row_groups); //! Merge the partially written blocks from one optimistic writer into another void Merge(OptimisticDataWriter &other); + void Merge(unique_ptr &other_manager); //! 
Rollback void Rollback(); @@ -49,6 +54,9 @@ class OptimisticDataWriter { private: //! Prepare a write to disk bool PrepareWrite(); + //! Flushes a specific row group to disk + void FlushToDisk(OptimisticWriteCollection &collection, const vector> &row_groups, + const vector &segment_indexes); private: //! The client context in which we're writing the data. diff --git a/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp index 4e901fb0f..9b1399169 100644 --- a/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp @@ -129,6 +129,7 @@ class PartialBlockManager { unique_lock GetLock() { return unique_lock(partial_block_lock); } + block_id_t GetFreeBlockId(); //! Returns a reference to the underlying block manager. BlockManager &GetBlockManager() const; diff --git a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp index ae7d74fd3..45374d899 100644 --- a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp @@ -22,6 +22,7 @@ namespace duckdb { class DatabaseInstance; struct MetadataHandle; +enum class FreeBlockType { NEWLY_USED_BLOCK, CHECKPOINTED_BLOCK }; struct EncryptionOptions { //! indicates whether the db is encrypted @@ -73,18 +74,22 @@ class SingleFileBlockManager : public BlockManager { unique_ptr CreateBlock(block_id_t block_id, FileBuffer *source_buffer) override; //! Return the next free block id block_id_t GetFreeBlockId() override; + //! Return the next free block id + block_id_t GetFreeBlockIdForCheckpoint() override; //! Check the next free block id - but do not assign or allocate it block_id_t PeekFreeBlockId() override; //! Returns whether or not a specified block is the root block bool IsRootBlock(MetaBlockPointer root) override; - //! Mark a block as free (immediately re-writeable) - void MarkBlockAsFree(block_id_t block_id) override; + //! Mark a block as included in a checkpoint + void MarkBlockACheckpointed(block_id_t block_id) override; //! Mark a block as used (no longer re-writeable) void MarkBlockAsUsed(block_id_t block_id) override; //! Mark a block as modified (re-writeable after a checkpoint) void MarkBlockAsModified(block_id_t block_id) override; //! Increase the reference count of a block. The block should hold at least one reference void IncreaseBlockReferenceCount(block_id_t block_id) override; + //! UnregisterBlock, only accepts non-temporary block ids + void UnregisterBlock(block_id_t id) override; //! Return the meta block id idx_t GetMetaBlock() override; //! Read the content of the block from disk @@ -158,7 +163,8 @@ class SingleFileBlockManager : public BlockManager { //! Return the blocks to which we will write the free list and modified blocks vector GetFreeListBlocks(); - void TrimFreeBlocks(); + void TrimFreeBlocks(const set &blocks); + void TrimFreeBlockRange(block_id_t start, block_id_t end); void IncreaseBlockReferenceCountInternal(block_id_t block_id); @@ -167,6 +173,8 @@ class SingleFileBlockManager : public BlockManager { void AddStorageVersionTag(); + block_id_t GetFreeBlockIdInternal(FreeBlockType type); + private: AttachedDatabase &db; //! The active DatabaseHeader, either 0 (h1) or 1 (h2) @@ -179,13 +187,15 @@ class SingleFileBlockManager : public BlockManager { FileBuffer header_buffer; //! 
The list of free blocks that can be written to currently set free_list; - //! The list of blocks that were freed since the last checkpoint. - set newly_freed_list; + //! The list of blocks that have been freed, but cannot yet be re-used because they are still in-use + set free_blocks_in_use; + //! The list of blocks that are in-use, but haven't been written as part of a checkpoint yet + set newly_used_blocks; //! The list of multi-use blocks (i.e. blocks that have >1 reference in the file) //! When a multi-use block is marked as modified, the reference count is decreased by 1 instead of directly //! Appending the block to the modified_blocks list unordered_map multi_use_blocks; - //! The list of blocks that will be added to the free list + //! The list of blocks that are no longer in-use, but cannot be re-used until the next checkpoint unordered_set modified_blocks; //! The current meta block id idx_t meta_block; diff --git a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp index d0a54c597..ff4ce4684 100644 --- a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -71,7 +71,7 @@ class StandardBufferManager : public BufferManager { void ReAllocate(shared_ptr &handle, idx_t block_size) final; BufferHandle Pin(shared_ptr &handle) final; - BufferHandle Pin(QueryContext context, shared_ptr &handle); + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) final; void Prefetch(vector> &handles) final; void Unpin(shared_ptr &handle) final; @@ -84,6 +84,8 @@ class StandardBufferManager : public BufferManager { //! Returns information about memory usage vector GetMemoryUsageInfo() const override; + BlockManager &GetTemporaryBlockManager() final; + //! Returns a list of all temporary files vector GetTemporaryFiles() final; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp index 2101bcb31..705e462fe 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp @@ -15,6 +15,8 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" #include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" namespace duckdb { struct SelectionVector; @@ -33,7 +35,16 @@ enum class StatsInfo : uint8_t { CAN_HAVE_NULL_AND_VALID_VALUES = 4 }; -enum class StatisticsType : uint8_t { NUMERIC_STATS, STRING_STATS, LIST_STATS, STRUCT_STATS, BASE_STATS, ARRAY_STATS }; +enum class StatisticsType : uint8_t { + NUMERIC_STATS, + STRING_STATS, + LIST_STATS, + STRUCT_STATS, + BASE_STATS, + ARRAY_STATS, + GEOMETRY_STATS, + VARIANT_STATS +}; class BaseStatistics { friend struct NumericStats; @@ -41,6 +52,8 @@ class BaseStatistics { friend struct StructStats; friend struct ListStats; friend struct ArrayStats; + friend struct GeometryStats; + friend struct VariantStats; public: DUCKDB_API ~BaseStatistics(); @@ -75,7 +88,7 @@ class BaseStatistics { } void Set(StatsInfo info); - void CombineValidity(BaseStatistics &left, BaseStatistics &right); + void CombineValidity(const BaseStatistics &left, const BaseStatistics &right); void CopyValidity(BaseStatistics &stats); //! Set that the CURRENT level can have null values //! 
Note that this is not correct for nested types unless this information is propagated in a different manner @@ -133,7 +146,7 @@ class BaseStatistics { private: //! The type of the logical segment - LogicalType type; + LogicalType type = LogicalType::INVALID; //! Whether or not the segment can contain NULL values bool has_null; //! Whether or not the segment can contain values that are not null @@ -146,6 +159,10 @@ class BaseStatistics { NumericStatsData numeric_data; //! String stats data, for string stats StringStatsData string_data; + //! Geometry stats data, for geometry stats + GeometryStatsData geometry_data; + //! Variant stats data, for variant stats + VariantStatsData variant_data; } stats_union; //! Child stats (for LIST and STRUCT) unsafe_unique_array child_stats; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp new file mode 100644 index 000000000..6c6cfa35a --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/geometry_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/array_ptr.hpp" +#include "duckdb/common/types/geometry.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; + +class GeometryTypeSet { +public: + static constexpr auto VERT_TYPES = 4; + static constexpr auto PART_TYPES = 8; + + static GeometryTypeSet Unknown() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0xFF; + } + return result; + } + static GeometryTypeSet Empty() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0; + } + return result; + } + + bool IsEmpty() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0) { + return false; + } + } + return true; + } + + bool IsUnknown() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0xFF) { + return false; + } + } + return true; + } + + void Add(GeometryType geom_type, VertexType vert_type) { + const auto vert_idx = static_cast(vert_type); + const auto geom_idx = static_cast(geom_type); + D_ASSERT(vert_idx < VERT_TYPES); + D_ASSERT(geom_idx < PART_TYPES); + sets[vert_idx] |= (1 << geom_idx); + } + + void Merge(const GeometryTypeSet &other) { + for (idx_t i = 0; i < VERT_TYPES; i++) { + sets[i] |= other.sets[i]; + } + } + + void Clear() { + for (idx_t i = 0; i < VERT_TYPES; i++) { + sets[i] = 0; + } + } + + void AddWKBType(int32_t wkb_type) { + const auto vert_idx = static_cast((wkb_type / 1000) % 10); + const auto geom_idx = static_cast(wkb_type % 1000); + D_ASSERT(vert_idx < VERT_TYPES); + D_ASSERT(geom_idx < PART_TYPES); + sets[vert_idx] |= (1 << geom_idx); + } + + vector ToWKBList() const { + vector result; + for (uint8_t vert_idx = 0; vert_idx < VERT_TYPES; vert_idx++) { + for (uint8_t geom_idx = 1; geom_idx < PART_TYPES; geom_idx++) { + if (sets[vert_idx] & (1 << geom_idx)) { + result.push_back(geom_idx + vert_idx * 1000); + } + } + } + return result; + } + + vector 
ToString(bool snake_case) const; + + uint8_t sets[VERT_TYPES]; +}; + +struct GeometryStatsData { + GeometryTypeSet types; + GeometryExtent extent; + + void SetEmpty() { + types = GeometryTypeSet::Empty(); + extent = GeometryExtent::Empty(); + } + + void SetUnknown() { + types = GeometryTypeSet::Unknown(); + extent = GeometryExtent::Unknown(); + } + + void Merge(const GeometryStatsData &other) { + types.Merge(other.types); + extent.Merge(other.extent); + } + + void Update(const string_t &geom_blob) { + // Parse type + const auto type_info = Geometry::GetType(geom_blob); + types.Add(type_info.first, type_info.second); + + // Update extent + Geometry::GetExtent(geom_blob, extent); + } +}; + +struct GeometryStats { + //! Unknown statistics + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + //! Empty statistics + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + + //! Check if a spatial predicate check with a constant could possibly be satisfied by rows given the statistics + DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, + const unique_ptr &expr); + + DUCKDB_API static GeometryExtent &GetExtent(BaseStatistics &stats); + DUCKDB_API static const GeometryExtent &GetExtent(const BaseStatistics &stats); + DUCKDB_API static GeometryTypeSet &GetTypes(BaseStatistics &stats); + DUCKDB_API static const GeometryTypeSet &GetTypes(const BaseStatistics &stats); + +private: + static GeometryStatsData &GetDataUnsafe(BaseStatistics &stats); + static const GeometryStatsData &GetDataUnsafe(const BaseStatistics &stats); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp index 0982f8905..6e5814a36 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp @@ -71,6 +71,8 @@ struct StringStats { ExpressionType comparison_type, const string &value); DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMin(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMax(BaseStatistics &stats, const string_t &value); DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); diff --git a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp new file mode 100644 index 000000000..e02072b50 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types/selection_vector.hpp" + +namespace duckdb { +class BaseStatistics; + +enum class VariantStatsShreddingState 
: uint8_t { + //! Uninitialized, not unshredded/shredded + UNINITIALIZED, + //! No shredding applied yet + NOT_SHREDDED, + //! Shredded consistently + SHREDDED, + //! Result from combining incompatible shreddings + INCONSISTENT +}; + +struct VariantStatsData { + //! Whether the VARIANT is stored in shredded form + VariantStatsShreddingState shredding_state; +}; + +struct VariantShreddedStats { +public: + DUCKDB_API static bool IsFullyShredded(const BaseStatistics &stats); +}; + +//! VARIANT as a type can hold arbitrarily typed values within the same column. +//! In storage, we apply shredding to the VARIANT column, this means that we find the most common type among all these +//! values. And for those values we store them separately from the rest of the values, in a structured way (like you +//! would store any other column). +struct VariantStats { +public: + DUCKDB_API static void Construct(BaseStatistics &stats); + +public: + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + DUCKDB_API static BaseStatistics CreateShredded(const LogicalType &shredded_type); + +public: + //! Stats related to the 'unshredded' column, which holds all data that doesn't fit in the structure of the shredded + //! column (if IsShredded()) + DUCKDB_API static void CreateUnshreddedStats(BaseStatistics &stats); + DUCKDB_API static const BaseStatistics &GetUnshreddedStats(const BaseStatistics &stats); + DUCKDB_API static BaseStatistics &GetUnshreddedStats(BaseStatistics &stats); + + DUCKDB_API static void SetUnshreddedStats(BaseStatistics &stats, unique_ptr new_stats); + DUCKDB_API static void SetUnshreddedStats(BaseStatistics &stats, const BaseStatistics &new_stats); + DUCKDB_API static void MarkAsNotShredded(BaseStatistics &stats); + +public: + //! Stats related to the 'shredded' column, which holds all structured data created during shredding + //! 
Returns the LogicalType that represents the shredding as a single DuckDB LogicalType (i.e STRUCT(col1 VARCHAR)) + DUCKDB_API LogicalType GetShreddedStructuredType(const BaseStatistics &stats); + DUCKDB_API static void CreateShreddedStats(BaseStatistics &stats, const LogicalType &shredded_type); + DUCKDB_API static bool IsShredded(const BaseStatistics &stats); + DUCKDB_API static const BaseStatistics &GetShreddedStats(const BaseStatistics &stats); + DUCKDB_API static BaseStatistics &GetShreddedStats(BaseStatistics &stats); + + DUCKDB_API static void SetShreddedStats(BaseStatistics &stats, unique_ptr new_stats); + DUCKDB_API static void SetShreddedStats(BaseStatistics &stats, const BaseStatistics &new_stats); + + DUCKDB_API static bool MergeShredding(BaseStatistics &stats, const BaseStatistics &other, + BaseStatistics &new_stats); + +public: + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); + +private: + static VariantStatsData &GetDataUnsafe(BaseStatistics &stats); + static const VariantStatsData &GetDataUnsafe(const BaseStatistics &stats); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_info.hpp b/src/duckdb/src/include/duckdb/storage/storage_info.hpp index 804a76403..f919bac33 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_info.hpp @@ -30,7 +30,7 @@ class QueryContext; #define DEFAULT_BLOCK_ALLOC_SIZE 262144ULL //! The default block header size. #define DEFAULT_BLOCK_HEADER_STORAGE_SIZE 8ULL -//! The default block header size. +//! The default block header size for encrypted blocks. #define DEFAULT_ENCRYPTION_BLOCK_HEADER_SIZE 40ULL //! The configurable block allocation size. #ifndef DUCKDB_BLOCK_HEADER_STORAGE_SIZE @@ -56,8 +56,6 @@ struct Storage { constexpr static idx_t DEFAULT_BLOCK_HEADER_SIZE = sizeof(idx_t); //! The default block header size for blocks written to storage. constexpr static idx_t MAX_BLOCK_HEADER_SIZE = 128ULL; - //! Block header size for encrypted blocks (64 bytes) - constexpr static idx_t ENCRYPTED_BLOCK_HEADER_SIZE = 64ULL; //! The default block size. constexpr static idx_t DEFAULT_BLOCK_SIZE = DEFAULT_BLOCK_ALLOC_SIZE - DEFAULT_BLOCK_HEADER_SIZE; @@ -180,20 +178,20 @@ class MainHeader { //! DatabaseHeader. struct DatabaseHeader { //! The iteration count, increases by 1 every time the storage is checkpointed. - uint64_t iteration; + uint64_t iteration = 0; //! A pointer to the initial meta block - idx_t meta_block; + idx_t meta_block = 0; //! A pointer to the block containing the free list - idx_t free_list; + idx_t free_list = 0; //! The number of blocks that is in the file as of this database header. If the file is larger than BLOCK_SIZE * //! block_count any blocks appearing AFTER block_count are implicitly part of the free_list. - uint64_t block_count; + uint64_t block_count = 0; //! The allocation size of blocks in this database file. Defaults to default_block_alloc_size (DBConfig). - idx_t block_alloc_size; + idx_t block_alloc_size = 0; //! 
The vector size of the database file - idx_t vector_size; + idx_t vector_size = 0; //! The serialization compatibility version - idx_t serialization_compatibility; + idx_t serialization_compatibility = 0; void Write(WriteStream &ser); static DatabaseHeader Read(const MainHeader &header, ReadStream &source); diff --git a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp index c96a76ff7..b7fa7ccec 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp @@ -14,7 +14,7 @@ #include "duckdb/storage/table_io_manager.hpp" #include "duckdb/storage/write_ahead_log.hpp" #include "duckdb/storage/database_size.hpp" -#include "duckdb/common/enums/checkpoint_type.hpp" +#include "duckdb/storage/checkpoint/checkpoint_options.hpp" #include "duckdb/storage/storage_options.hpp" namespace duckdb { @@ -47,17 +47,6 @@ class StorageCommitState { } }; -struct CheckpointOptions { - CheckpointOptions() - : wal_action(CheckpointWALAction::DONT_DELETE_WAL), action(CheckpointAction::CHECKPOINT_IF_REQUIRED), - type(CheckpointType::FULL_CHECKPOINT) { - } - - CheckpointWALAction wal_action; - CheckpointAction action; - CheckpointType type; -}; - //! StorageManager is responsible for managing the physical storage of a persistent database. class StorageManager { public: @@ -80,10 +69,17 @@ class StorageManager { //! Gets the size of the WAL, or zero, if there is no WAL. idx_t GetWALSize(); + bool HasWAL() const; + void AddWALSize(idx_t size); + void SetWALSize(idx_t size); //! Gets the WAL of the StorageManager, or nullptr, if there is no WAL. optional_ptr GetWAL(); - //! Deletes the WAL file, and resets the unique pointer. - void ResetWAL(); + //! Write that we started a checkpoint to the WAL if there is one - returns whether or not there is a WAL + bool WALStartCheckpoint(MetaBlockPointer meta_block, CheckpointOptions &options); + //! Finishes a checkpoint + void WALFinishCheckpoint(); + // Get the WAL lock + unique_ptr> GetWALLock(); //! Returns the database file path string GetDBPath() const { @@ -93,7 +89,11 @@ class StorageManager { return load_complete; } //! The path to the WAL, derived from the database file path - string GetWALPath() const; + string GetWALPath(const string &suffix = ".wal"); + //! The path to the WAL that is used while a checkpoint is running + string GetCheckpointWALPath(); + //! The path to the WAL that is used while recovering from a crash involving the checkpoint WAL + string GetRecoveryWALPath(); bool InMemory() const; virtual bool AutomaticCheckpoint(idx_t estimated_wal_bytes) = 0; @@ -116,12 +116,6 @@ class StorageManager { D_ASSERT(HasStorageVersion()); return storage_version.GetIndex(); } - void AddInMemoryChange(idx_t size) { - in_memory_change_size += size; - } - void ResetInMemoryChange() { - in_memory_change_size = 0; - } bool CompressionIsEnabled() const { return storage_options.compress_in_memory == CompressInMemory::COMPRESS; } @@ -147,8 +141,12 @@ class StorageManager { AttachedDatabase &db; //! The path of the database string path; + //! The WAL path + string wal_path; //! The WriteAheadLog of the storage manager unique_ptr wal; + //! Mutex used to control writes to the WAL + mutex wal_lock; //! Whether or not the database is opened in read-only mode bool read_only; //! When loading a database, we do not yet set the wal-field. Therefore, GetWriteAheadLog must @@ -156,8 +154,9 @@ class StorageManager { bool load_complete = false; //! 
The serialization compatibility version when reading and writing from this database optional_idx storage_version; - //! Estimated size of changes for determining automatic checkpointing on in-memory databases - atomic in_memory_change_size; + //! Estimated size of changes for determining automatic checkpointing on in-memory databases and databases without a + //! WAL. + atomic wal_size; //! Storage options passed in through configuration StorageOptions storage_options; diff --git a/src/duckdb/src/include/duckdb/storage/storage_options.hpp b/src/duckdb/src/include/duckdb/storage/storage_options.hpp index 786924b2c..4cf1f539b 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_options.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_options.hpp @@ -40,11 +40,4 @@ struct StorageOptions { void Initialize(const unordered_map &options); }; -inline void ClearUserKey(shared_ptr const &encryption_key) { - if (encryption_key && !encryption_key->empty()) { - memset(&(*encryption_key)[0], 0, encryption_key->size()); - encryption_key->clear(); - } -} - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp index b5342829c..b59f327b8 100644 --- a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -67,7 +67,7 @@ struct UncompressedStringStorage { static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); static idx_t StringFinalAnalyze(AnalyzeState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); @@ -118,7 +118,9 @@ struct UncompressedStringStorage { return i; } remaining_space -= sizeof(int32_t); - if (!data.validity.RowIsValid(source_idx)) { + const bool is_null = !data.validity.RowIsValid(source_idx); + if (is_null) { + stats.statistics.SetHasNullFast(); // null value is stored as a copy of the last value, this is done to be able to efficiently do the // string_length calculation if (target_idx > 0) { @@ -201,7 +203,12 @@ struct UncompressedStringStorage { public: static inline void UpdateStringStats(SegmentStatistics &stats, const string_t &new_value) { - StringStats::Update(stats.statistics, new_value); + stats.statistics.SetHasNoNullFast(); + if (stats.statistics.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + GeometryStats::Update(stats.statistics, new_value); + } else { + StringStats::Update(stats.statistics, new_value); + } } static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer dict); @@ -239,6 +246,6 @@ struct UncompressedStringStorage { static unique_ptr SerializeState(ColumnSegment &segment); static unique_ptr DeserializeState(Deserializer &deserializer); - static void CleanupState(ColumnSegment &segment); + static void VisitBlockIds(const ColumnSegment &segment, BlockIdVisitor &visitor); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp index 0a5c7b170..f63d40633 100644 --- 
a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp @@ -24,12 +24,16 @@ class LocalTableStorage; class RowGroup; class UpdateSegment; class TableCatalogEntry; +template +struct SegmentNode; +class RowGroupSegmentTree; +class CheckpointLock; struct TableAppendState; struct ColumnAppendState { //! The current segment of the append - ColumnSegment *current; + optional_ptr> current; //! Child append states vector child_appends; //! The write lock that is held by the append @@ -62,12 +66,16 @@ struct TableAppendState { RowGroupAppendState row_group_append_state; unique_lock append_lock; + shared_ptr table_lock; row_t row_start; row_t current_row; //! The total number of rows appended by the append operation idx_t total_append_count; + idx_t row_group_start; + //! The row group segment tree we are appending to + shared_ptr row_groups; //! The first row-group that has been appended to - RowGroup *start_row_group; + optional_ptr> start_row_group; //! The transaction data TransactionData transaction; //! Table statistics diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index abc9577a3..e66376978 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -16,16 +16,11 @@ namespace duckdb { //! List column data represents a list class ArrayColumnData : public ColumnData { public: - ArrayColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type, optional_ptr parent = nullptr); - - //! The child-column of the list - unique_ptr child_column; - //! 
The validity column data of the array - ValidityColumnData validity; + ArrayColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); public: - void SetStart(idx_t new_start) override; + void SetDataType(ColumnDataType data_type) override; FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; @@ -44,31 +39,41 @@ class ArrayColumnData : public ColumnData { void InitializeAppend(ColumnAppendState &state) override; void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; - void RevertAppend(row_t start_row) override; + void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; unique_ptr GetUpdateStatistics() override; - void CommitDropColumn() override; + void VisitBlockIds(BlockIdVisitor &visitor) const override; - unique_ptr CreateCheckpointState(RowGroup &row_group, + unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; - unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; + + void SetValidityData(shared_ptr validity); + void SetChildData(shared_ptr child_column); + +protected: + //! The child-column of the list + shared_ptr child_column; + //! 
The validity column data of the array + shared_ptr validity; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp index 44b92dd74..db959f4cd 100644 --- a/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/chunk_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/vector_size.hpp" #include "duckdb/common/atomic.hpp" +#include "duckdb/execution/index/index_pointer.hpp" namespace duckdb { class RowGroup; @@ -20,6 +21,7 @@ struct TransactionData; struct DeleteInfo; class Serializer; class Deserializer; +class FixedSizeAllocator; enum class ChunkInfoType : uint8_t { CONSTANT_INFO, VECTOR_INFO, EMPTY_INFO }; @@ -38,19 +40,21 @@ class ChunkInfo { public: //! Gets up to max_count entries from the chunk info. If the ret is 0>ret>max_count, the selection vector is filled //! with the tuples - virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) = 0; + virtual idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const = 0; virtual idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) = 0; + virtual idx_t GetCheckpointRowCount(TransactionData transaction, idx_t max_count) = 0; //! Returns whether or not a single row in the ChunkInfo should be used or not for the given transaction virtual bool Fetch(TransactionData transaction, row_t row) = 0; virtual void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) = 0; - virtual idx_t GetCommittedDeletedCount(idx_t max_count) = 0; - virtual bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const; + virtual idx_t GetCommittedDeletedCount(idx_t max_count) const = 0; + virtual bool Cleanup(transaction_t lowest_transaction) const; + virtual string ToString(idx_t max_count) const = 0; virtual bool HasDeletes() const = 0; virtual void Write(WriteStream &writer) const; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); public: template @@ -81,13 +85,15 @@ class ChunkConstantInfo : public ChunkInfo { transaction_t delete_id; public: - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetCheckpointRowCount(TransactionData transaction, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; - bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; + bool Cleanup(transaction_t lowest_transaction) const override; + string ToString(idx_t max_count) const override; bool HasDeletes() const override; @@ -105,27 +111,21 @@ class ChunkVectorInfo : public ChunkInfo { static constexpr const ChunkInfoType TYPE = ChunkInfoType::VECTOR_INFO; public: - explicit ChunkVectorInfo(idx_t start); - - //! 
The transaction ids of the transactions that inserted the tuples (if any) - transaction_t inserted[STANDARD_VECTOR_SIZE]; - transaction_t insert_id; - bool same_inserted_id; - - //! The transaction ids of the transactions that deleted the tuples (if any) - transaction_t deleted[STANDARD_VECTOR_SIZE]; - bool any_deleted; + explicit ChunkVectorInfo(FixedSizeAllocator &allocator, idx_t start, transaction_t insert_id = 0); + ~ChunkVectorInfo() override; public: idx_t GetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; - idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const override; idx_t GetCommittedSelVector(transaction_t min_start_id, transaction_t min_transaction_id, SelectionVector &sel_vector, idx_t max_count) override; + idx_t GetCheckpointRowCount(TransactionData transaction, idx_t max_count) override; bool Fetch(TransactionData transaction, row_t row) override; void CommitAppend(transaction_t commit_id, idx_t start, idx_t end) override; - bool Cleanup(transaction_t lowest_transaction, unique_ptr &result) const override; - idx_t GetCommittedDeletedCount(idx_t max_count) override; + bool Cleanup(transaction_t lowest_transaction) const override; + idx_t GetCommittedDeletedCount(idx_t max_count) const override; + string ToString(idx_t max_count) const override; void Append(idx_t start, idx_t end, transaction_t commit_id); @@ -138,14 +138,32 @@ class ChunkVectorInfo : public ChunkInfo { void CommitDelete(transaction_t commit_id, const DeleteInfo &info); bool HasDeletes() const override; + bool AnyDeleted() const; + bool HasConstantInsertionId() const; + transaction_t ConstantInsertId() const; void Write(WriteStream &writer) const override; - static unique_ptr Read(ReadStream &reader); + static unique_ptr Read(FixedSizeAllocator &allocator, ReadStream &reader); private: template idx_t TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const; + + IndexPointer GetInsertedPointer() const; + IndexPointer GetDeletedPointer() const; + IndexPointer GetInitializedInsertedPointer(); + IndexPointer GetInitializedDeletedPointer(); + +private: + FixedSizeAllocator &allocator; + //! The transaction ids of the transactions that inserted the tuples (if any) + IndexPointer inserted_data; + //! The constant insert id (if there is only one) + transaction_t constant_insert_id; + + //! 
The transaction ids of the transactions that deleted the tuples (if any) + IndexPointer deleted_data; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp index 9e5f7f98b..5537c8814 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp @@ -25,19 +25,27 @@ class PartialBlockManager; class TableDataWriter; struct ColumnCheckpointState { - ColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager); + ColumnCheckpointState(const RowGroup &row_group, ColumnData &original_column, + PartialBlockManager &partial_block_manager); virtual ~ColumnCheckpointState(); - RowGroup &row_group; - ColumnData &column_data; - ColumnSegmentTree new_tree; + const RowGroup &row_group; + const ColumnData &original_column; vector data_pointers; unique_ptr global_stats; protected: PartialBlockManager &partial_block_manager; + shared_ptr result_column; + +private: + ColumnData &original_column_mutable; public: + virtual shared_ptr CreateEmptyColumnData(); + virtual ColumnData &GetResultColumn(); + virtual shared_ptr GetFinalResult(); + virtual unique_ptr GetStatistics(); virtual void FlushSegmentInternal(unique_ptr segment, idx_t segment_size); diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index 400daeaa6..300ab2ffc 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -35,30 +35,33 @@ struct RowGroupWriteInfo; struct TableScanOptions; struct TransactionData; struct PersistentColumnData; +class ValidityColumnData; using column_segment_vector_t = vector>; struct ColumnCheckpointInfo { - ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) : info(info), column_idx(column_idx) { - } + ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx); - RowGroupWriteInfo &info; idx_t column_idx; public: + PartialBlockManager &GetPartialBlockManager(); CompressionType GetCompressionType(); + +private: + RowGroupWriteInfo &info; }; -class ColumnData { +enum class ColumnDataType { MAIN_TABLE, INITIAL_TRANSACTION_LOCAL, TRANSACTION_LOCAL, CHECKPOINT_TARGET }; + +class ColumnData : public enable_shared_from_this { friend class ColumnDataCheckpointer; public: - ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, LogicalType type, - optional_ptr parent); + ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); virtual ~ColumnData(); - //! The start row - idx_t start; //! The count of the column data atomic count; //! 
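Aside on the ChunkVectorInfo rework above: the fixed per-vector insert/delete arrays are replaced by allocator-backed pointers plus a constant_insert_id, so a vector whose rows were all written by a single transaction never pays for a per-row array. The following standalone sketch (hypothetical names, plain heap allocation instead of DuckDB's FixedSizeAllocator) shows the shape of that optimization, not the actual implementation.

#include <cstdint>
#include <memory>

// Minimal stand-in for per-vector insert-id tracking: keep one constant id while
// every row shares it, and only materialize a per-row array once a different id
// is appended. Hypothetical example, not DuckDB's ChunkVectorInfo.
class InsertVersionSketch {
public:
	static constexpr size_t VECTOR_SIZE = 2048; // illustrative vector size

	explicit InsertVersionSketch(uint64_t first_insert_id) : constant_insert_id(first_insert_id) {
	}

	bool HasConstantInsertionId() const {
		return !per_row_ids;
	}

	void Append(size_t start, size_t end, uint64_t insert_id) {
		if (HasConstantInsertionId() && insert_id == constant_insert_id) {
			return; // still a single writer - nothing to materialize
		}
		Materialize();
		for (size_t i = start; i < end; i++) {
			per_row_ids[i] = insert_id;
		}
	}

	uint64_t GetInsertId(size_t row) const {
		return HasConstantInsertionId() ? constant_insert_id : per_row_ids[row];
	}

private:
	void Materialize() {
		if (per_row_ids) {
			return;
		}
		per_row_ids.reset(new uint64_t[VECTOR_SIZE]);
		for (size_t i = 0; i < VECTOR_SIZE; i++) {
			per_row_ids[i] = constant_insert_id;
		}
	}

	uint64_t constant_insert_id;
	std::unique_ptr<uint64_t[]> per_row_ids; // lazily allocated per-row ids
};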
The block manager @@ -73,7 +76,7 @@ class ColumnData { public: virtual FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter); - BlockManager &GetBlockManager() { + BlockManager &GetBlockManager() const { return block_manager; } DatabaseInstance &GetDatabase() const; @@ -87,18 +90,31 @@ class ColumnData { optional_ptr GetCompressionFunction() const { return compression.get(); } + virtual void SetDataType(ColumnDataType data_type); + ColumnDataType GetDataType() const { + return data_type; + } bool HasParent() const { - return parent != nullptr; + return parent; + } + void SetParent(optional_ptr parent) { + this->parent = parent; } const ColumnData &Parent() const { D_ASSERT(HasParent()); return *parent; } + const LogicalType &GetType() const { + return type; + } + ColumnSegmentTree &GetSegmentTree() { + return data; + } + void SetCount(idx_t new_count) { + this->count = new_count; + } - virtual void SetStart(idx_t new_start); - //! The root type of the column - const LogicalType &RootType() const; //! Whether or not the column has any updates bool HasUpdates() const; bool HasChanges(idx_t start_row, idx_t end_row) const; @@ -146,7 +162,7 @@ class ColumnData { void Append(ColumnAppendState &state, Vector &vector, idx_t count); virtual void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count); //! Revert a set of appends to the ColumnData - virtual void RevertAppend(row_t start_row); + virtual void RevertAppend(row_t new_count); //! Fetch the vector from the column data that belongs to this specific row virtual idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result); @@ -154,20 +170,20 @@ class ColumnData { virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx); - virtual void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count); - virtual void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth); + virtual void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start); + virtual void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start); virtual unique_ptr GetUpdateStatistics(); - virtual void CommitDropColumn(); + virtual void VisitBlockIds(BlockIdVisitor &visitor) const; - virtual unique_ptr CreateCheckpointState(RowGroup &row_group, + virtual unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager); - virtual unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info); + virtual unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info); - virtual void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector); + virtual void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, Vector &scan_vector) const; virtual bool IsPersistent(); vector GetDataPointers(); @@ -176,23 +192,22 @@ class ColumnData { void InitializeColumn(PersistentColumnData &column_data); virtual void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats); static shared_ptr 
Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, ReadStream &source, const LogicalType &type); + ReadStream &source, const LogicalType &type); - virtual void GetColumnSegmentInfo(idx_t row_group_index, vector col_path, vector &result); + virtual void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result); virtual void Verify(RowGroup &parent); FilterPropagateResult CheckZonemap(TableFilter &filter); static shared_ptr CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, const LogicalType &type, + const LogicalType &type, + ColumnDataType data_type = ColumnDataType::MAIN_TABLE, optional_ptr parent = nullptr); - static unique_ptr CreateColumnUnique(BlockManager &block_manager, DataTableInfo &info, - idx_t column_index, idx_t start_row, const LogicalType &type, - optional_ptr parent = nullptr); void MergeStatistics(const BaseStatistics &other); void MergeIntoStatistics(BaseStatistics &other); - unique_ptr GetStatistics(); + unique_ptr GetStatistics() const; protected: //! Append a transient segment @@ -213,13 +228,12 @@ class ColumnData { void FilterVector(ColumnScanState &state, Vector &result, idx_t target_count, SelectionVector &sel, idx_t &sel_count, const TableFilter &filter, TableFilterState &filter_state); - void ClearUpdates(); void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result, idx_t scan_count, bool allow_updates, bool scan_committed); void FetchUpdateRow(TransactionData transaction, row_t row_id, Vector &result, idx_t result_idx); - void UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector); - idx_t FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector); + void UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, Vector &base_vector, idx_t row_group_start); + idx_t FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector, idx_t row_group_start); idx_t GetVectorCount(idx_t vector_index) const; @@ -241,16 +255,31 @@ class ColumnData { atomic allocation_size; private: + //! Whether or not this column data belongs to a main table or if it is transaction local + atomic data_type; //! The parent column (if any) optional_ptr parent; //! The compression function used by the ColumnData //! 
This is empty if the segments have mixed compression or the ColumnData is empty atomic_ptr compression; + +public: + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } }; struct PersistentColumnData { - explicit PersistentColumnData(PhysicalType physical_type); - PersistentColumnData(PhysicalType physical_type, vector pointers); +public: + explicit PersistentColumnData(const LogicalType &logical_type); + PersistentColumnData(const LogicalType &logical_type, vector pointers); // disable copy constructors PersistentColumnData(const PersistentColumnData &other) = delete; PersistentColumnData &operator=(const PersistentColumnData &) = delete; @@ -259,16 +288,21 @@ struct PersistentColumnData { PersistentColumnData &operator=(PersistentColumnData &&) = default; ~PersistentColumnData(); - PhysicalType physical_type; - vector pointers; - vector child_columns; - bool has_updates = false; - +public: void Serialize(Serializer &serializer) const; static PersistentColumnData Deserialize(Deserializer &deserializer); void DeserializeField(Deserializer &deserializer, field_id_t field_idx, const char *field_name, const LogicalType &type); bool HasUpdates() const; + void SetVariantShreddedType(const LogicalType &shredded_type); + +public: + PhysicalType physical_type; + LogicalTypeId logical_type_id; + vector pointers; + vector child_columns; + bool has_updates = false; + LogicalType variant_shredded_type; }; struct PersistentRowGroupData { diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp index f31e77521..eb6efb105 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp @@ -22,17 +22,16 @@ struct ColumnDataCheckpointData { ColumnDataCheckpointData() { } ColumnDataCheckpointData(ColumnCheckpointState &checkpoint_state, ColumnData &col_data, DatabaseInstance &db, - RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info, - StorageManager &storage_manager) + const RowGroup &row_group, StorageManager &storage_manager) : checkpoint_state(checkpoint_state), col_data(col_data), db(db), row_group(row_group), - checkpoint_info(checkpoint_info), storage_manager(storage_manager) { + storage_manager(storage_manager) { } public: CompressionFunction &GetCompressionFunction(CompressionType type); const LogicalType &GetType() const; ColumnData &GetColumnData(); - RowGroup &GetRowGroup(); + const RowGroup &GetRowGroup(); ColumnCheckpointState &GetCheckpointState(); DatabaseInstance &GetDatabase(); StorageManager &GetStorageManager(); @@ -41,8 +40,7 @@ struct ColumnDataCheckpointData { optional_ptr checkpoint_state; optional_ptr col_data; optional_ptr db; - optional_ptr row_group; - optional_ptr checkpoint_info; + optional_ptr row_group; optional_ptr storage_manager; }; @@ -63,7 +61,7 @@ struct CheckpointAnalyzeResult { class ColumnDataCheckpointer { public: ColumnDataCheckpointer(vector> &states, StorageManager &storage_manager, - RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info); + const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info); public: void Checkpoint(); @@ -73,8 +71,8 @@ class ColumnDataCheckpointer { void ScanSegments(const std::function &callback); vector DetectBestCompressionMethod(); void WriteToDisk(); - bool 
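The patch also gives ColumnData a debug-checked Cast helper for downcasting to concrete column-data classes. A generic, standalone version of that pattern is sketched below; the names are hypothetical and a plain assert stands in for DuckDB's debug cast check.

#include <cassert>
#include <iostream>

// Generic "checked downcast" pattern: verify the dynamic type in debug builds,
// fall back to a plain static cast in release builds.
struct ColumnBase {
	virtual ~ColumnBase() = default;

	template <class TARGET>
	TARGET &Cast() {
		assert(dynamic_cast<TARGET *>(this) != nullptr); // compiled out with NDEBUG
		return static_cast<TARGET &>(*this);
	}
};

struct ListColumn : ColumnBase {
	int child_count = 0;
};

int main() {
	ListColumn list;
	ColumnBase &base = list;
	// Downcast through the checked helper instead of writing the cast by hand.
	base.Cast<ListColumn>().child_count = 42;
	std::cout << list.child_count << "\n"; // prints 42
	return 0;
}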
HasChanges(ColumnData &col_data); void WritePersistentSegments(ColumnCheckpointState &state); + bool HasChanges(ColumnData &col_data); void InitAnalyze(); void DropSegments(); bool ValidityCoveredByBasedata(vector &result); @@ -82,11 +80,11 @@ class ColumnDataCheckpointer { private: vector> &checkpoint_states; StorageManager &storage_manager; - RowGroup &row_group; + const RowGroup &row_group; Vector intermediate; ColumnCheckpointInfo &checkpoint_info; - vector has_changes; + bool has_changes = false; //! For every column data that is being checkpointed, the applicable functions vector>> compression_functions; //! For every column data that is being checkpointed, the analyze state of functions being tried diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp index 61b2c0d4f..ed8dfb045 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -29,7 +29,6 @@ class DatabaseInstance; class TableFilter; class Transaction; class UpdateSegment; - struct ColumnAppendState; struct ColumnFetchState; struct ColumnScanState; @@ -43,24 +42,23 @@ class ColumnSegment : public SegmentBase { public: //! Construct a column segment. ColumnSegment(DatabaseInstance &db, shared_ptr block, const LogicalType &type, - const ColumnSegmentType segment_type, const idx_t start, const idx_t count, - CompressionFunction &function_p, BaseStatistics statistics, const block_id_t block_id_p, - const idx_t offset, const idx_t segment_size_p, - unique_ptr segment_state_p = nullptr); + const ColumnSegmentType segment_type, const idx_t count, CompressionFunction &function_p, + BaseStatistics statistics, const block_id_t block_id_p, const idx_t offset, + const idx_t segment_size_p, unique_ptr segment_state_p = nullptr); //! Construct a column segment from another column segment. //! The other column segment becomes invalid (std::move). - ColumnSegment(ColumnSegment &other, const idx_t start); + ColumnSegment(ColumnSegment &other); ~ColumnSegment(); public: static unique_ptr CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, block_id_t id, idx_t offset, const LogicalType &type_p, - idx_t start, idx_t count, CompressionType compression_type, + idx_t count, CompressionType compression_type, BaseStatistics statistics, unique_ptr segment_state); static unique_ptr CreateTransientSegment(DatabaseInstance &db, CompressionFunction &function, - const LogicalType &type, const idx_t start, - const idx_t segment_size, BlockManager &block_manager); + const LogicalType &type, const idx_t segment_size, + BlockManager &block_manager); public: void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state); @@ -97,7 +95,7 @@ class ColumnSegment : public SegmentBase { //! Returns the number of bytes occupied within the segment idx_t FinalizeAppend(ColumnAppendState &state); //! Revert an append made to this segment - void RevertAppend(idx_t start_row); + void RevertAppend(idx_t new_count); //! Convert a transient in-memory segment to a persistent segment backed by an on-disk block. //! Only used during checkpointing. @@ -107,7 +105,7 @@ class ColumnSegment : public SegmentBase { void MarkAsPersistent(shared_ptr block, uint32_t offset_in_block); void SetBlock(shared_ptr block, uint32_t offset); //! 
Gets a data pointer from a persistent column segment - DataPointer GetDataPointer(); + DataPointer GetDataPointer(idx_t row_start); block_id_t GetBlockId() { D_ASSERT(segment_type == ColumnSegmentType::PERSISTENT); @@ -126,17 +124,11 @@ class ColumnSegment : public SegmentBase { return offset; } - idx_t GetRelativeIndex(idx_t row_index) { - D_ASSERT(row_index >= this->start); - D_ASSERT(row_index <= this->start + this->count); - return row_index - this->start; - } - - optional_ptr GetSegmentState() { + optional_ptr GetSegmentState() const { return segment_state.get(); } - void CommitDropSegment(); + void VisitBlockIds(BlockIdVisitor &visitor) const; private: void Scan(ColumnScanState &state, idx_t scan_count, Vector &result); diff --git a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp index 9511cb18a..fc009bf8c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp @@ -47,6 +47,8 @@ struct DataTableInfo { unique_ptr GetSharedLock() { return checkpoint_lock.GetSharedLock(); } + bool IsUnseenCheckpoint(transaction_t checkpoint_id); + void VerifyIndexBuffers(); string GetSchemaName(); string GetTableName(); @@ -69,6 +71,8 @@ struct DataTableInfo { vector index_storage_infos; //! Lock held while checkpointing StorageLock checkpoint_lock; + //! The last seen checkpoint while doing a concurrent operation, if any + optional_idx last_seen_checkpoint; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp index e36e6169d..f784e57a6 100644 --- a/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp @@ -17,7 +17,7 @@ namespace duckdb { class InMemoryCheckpointer final : public CheckpointWriter { public: InMemoryCheckpointer(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, - StorageManager &storage_manager, CheckpointType checkpoint_type); + StorageManager &storage_manager, CheckpointOptions options); void CreateCheckpoint() override; @@ -27,8 +27,8 @@ class InMemoryCheckpointer final : public CheckpointWriter { optional_ptr GetClientContext() const { return context; } - CheckpointType GetCheckpointType() const { - return checkpoint_type; + CheckpointOptions GetCheckpointOptions() const { + return options; } PartialBlockManager &GetPartialBlockManager() { return partial_block_manager; @@ -41,7 +41,7 @@ class InMemoryCheckpointer final : public CheckpointWriter { optional_ptr context; PartialBlockManager partial_block_manager; StorageManager &storage_manager; - CheckpointType checkpoint_type; + CheckpointOptions options; }; class InMemoryTableDataWriter : public TableDataWriter { @@ -53,7 +53,8 @@ class InMemoryTableDataWriter : public TableDataWriter { void FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) override; unique_ptr GetRowGroupWriter(RowGroup &row_group) override; - CheckpointType GetCheckpointType() const override; + void FlushPartialBlocks() override; + CheckpointOptions GetCheckpointOptions() const override; MetadataManager &GetMetadataManager() override; private: @@ -66,7 +67,7 @@ class InMemoryRowGroupWriter : public RowGroupWriter { InMemoryCheckpointer &checkpoint_manager); public: - CheckpointType 
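With GetRelativeIndex and the segment-local start row removed above, absolute row ids now have to be translated into segment-relative offsets by whoever tracks each segment's row_start (the segment tree node). A small standalone sketch of that translation, under hypothetical types and a linear search where the real tree would use its index:

#include <cassert>
#include <cstdint>
#include <vector>

// The segment only knows its own row count...
struct SegmentSketch {
	uint64_t count;
};

// ...while the container records where each segment starts.
struct SegmentEntrySketch {
	uint64_t row_start;
	SegmentSketch segment;
};

// Find the segment holding absolute row id "row" and return the offset inside it.
static uint64_t RelativeIndex(const std::vector<SegmentEntrySketch> &segments, uint64_t row) {
	for (const auto &entry : segments) {
		if (row >= entry.row_start && row < entry.row_start + entry.segment.count) {
			return row - entry.row_start;
		}
	}
	assert(false && "row id out of range");
	return 0;
}

int main() {
	std::vector<SegmentEntrySketch> segments {{0, {1000}}, {1000, {500}}};
	assert(RelativeIndex(segments, 1200) == 200);
	return 0;
}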
GetCheckpointType() const override; + CheckpointOptions GetCheckpointOptions() const override; WriteStream &GetPayloadWriter() override; MetaBlockPointer GetMetaBlockPointer() override; optional_ptr GetMetadataManager() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index c8e75d136..1bfa67480 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -16,16 +16,11 @@ namespace duckdb { //! List column data represents a list class ListColumnData : public ColumnData { public: - ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type, optional_ptr parent = nullptr); - - //! The child-column of the list - unique_ptr child_column; - //! The validity column data of the list - ValidityColumnData validity; + ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); public: - void SetStart(idx_t new_start) override; + void SetDataType(ColumnDataType data_type) override; FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; @@ -42,29 +37,39 @@ class ListColumnData : public ColumnData { void InitializeAppend(ColumnAppendState &state) override; void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; - void RevertAppend(row_t start_row) override; + void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; unique_ptr GetUpdateStatistics() override; - void CommitDropColumn() override; + void VisitBlockIds(BlockIdVisitor &visitor) const override; - unique_ptr CreateCheckpointState(RowGroup &row_group, + unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; - unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) 
override; + + void SetValidityData(shared_ptr validity_p); + void SetChildData(shared_ptr child_column_p); + +protected: + //! The child-column of the list + shared_ptr child_column; + //! The validity column data of the list + shared_ptr validity; private: uint64_t FetchListOffset(idx_t row_idx); diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index 242e19121..37f89f847 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -19,7 +19,7 @@ #include "duckdb/storage/block.hpp" #include "duckdb/common/enums/checkpoint_type.hpp" #include "duckdb/storage/storage_index.hpp" -#include "duckdb/function/partition_stats.hpp" +#include "duckdb/storage/checkpoint/checkpoint_options.hpp" namespace duckdb { class AttachedDatabase; @@ -37,6 +37,7 @@ class TableStatistics; struct ColumnSegmentInfo; class Vector; struct ColumnCheckpointState; +struct PartitionStatistics; struct PersistentColumnData; struct PersistentRowGroupData; struct RowGroupPointer; @@ -50,22 +51,38 @@ class MetadataManager; class RowVersionManager; class ScanFilterInfo; class StorageCommitState; +template +struct SegmentNode; +enum class ColumnDataType; struct RowGroupWriteInfo { RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, - CheckpointType checkpoint_type = CheckpointType::FULL_CHECKPOINT) - : manager(manager), compression_types(compression_types), checkpoint_type(checkpoint_type) { - } + CheckpointOptions options = CheckpointOptions()); + RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p); +private: PartialBlockManager &manager; + +public: const vector &compression_types; - CheckpointType checkpoint_type; + CheckpointOptions options; + +public: + PartialBlockManager &GetPartialBlockManager(idx_t column_idx); + +private: + optional_ptr>> column_partial_block_managers; }; struct RowGroupWriteData { + shared_ptr result_row_group; vector> states; vector statistics; - vector existing_pointers; + bool reuse_existing_metadata_blocks = false; + bool should_checkpoint = true; + vector existing_extra_metadata_blocks; + optional_idx write_count; }; class RowGroup : public SegmentBase { @@ -73,7 +90,7 @@ class RowGroup : public SegmentBase { friend class ColumnData; public: - RowGroup(RowGroupCollection &collection, idx_t start, idx_t count); + RowGroup(RowGroupCollection &collection, idx_t count); RowGroup(RowGroupCollection &collection, RowGroupPointer pointer); RowGroup(RowGroupCollection &collection, PersistentRowGroupData &data); ~RowGroup(); @@ -85,26 +102,25 @@ class RowGroup : public SegmentBase { atomic> version_info; //! The owned version info of the row_group (inserted and deleted tuple info) shared_ptr owned_version_info; - //! The column data of the row_group - vector> columns; + //! The column data of the row_group (mutable because `const` can lazily load) + mutable vector> columns; public: - void MoveToCollection(RowGroupCollection &collection, idx_t new_start); - RowGroupCollection &GetCollection() { + void MoveToCollection(RowGroupCollection &collection); + RowGroupCollection &GetCollection() const { return collection.get(); } //! Returns the list of meta block pointers used by the columns - vector GetColumnPointers(); - //! 
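RowGroupWriteInfo above can now carry an optional set of per-column partial block managers alongside the shared one; GetPartialBlockManager(column_idx) presumably selects the per-column manager when one was registered and falls back to the shared manager otherwise. A minimal sketch of that fallback pattern, with hypothetical types rather than DuckDB's:

#include <cstddef>
#include <vector>

// Fallback pattern: use a per-column resource when one was registered,
// otherwise the shared resource.
struct BlockManagerSketch {
	int id = 0;
};

struct WriteInfoSketch {
	BlockManagerSketch &shared_manager;
	// optional per-column managers; nullptr/empty means "only the shared one"
	std::vector<BlockManagerSketch> *per_column_managers = nullptr;

	BlockManagerSketch &GetManager(size_t column_idx) {
		if (per_column_managers && column_idx < per_column_managers->size()) {
			return (*per_column_managers)[column_idx];
		}
		return shared_manager;
	}
};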
Returns the list of meta block pointers used by the deletes - const vector &GetDeletesPointers() const { - return deletes_pointers; - } - BlockManager &GetBlockManager(); - DataTableInfo &GetTableInfo(); + vector GetOrComputeExtraMetadataBlocks(bool force_compute = false); + + const vector &GetColumnStartPointers() const; + + BlockManager &GetBlockManager() const; + DataTableInfo &GetTableInfo() const; unique_ptr AlterType(RowGroupCollection &collection, const LogicalType &target_type, idx_t changed_idx, ExpressionExecutor &executor, CollectionScanState &scan_state, - DataChunk &scan_chunk); + SegmentNode &node, DataChunk &scan_chunk); unique_ptr AddColumn(RowGroupCollection &collection, ColumnDefinition &new_column, ExpressionExecutor &executor, Vector &intermediate); unique_ptr RemoveColumn(RowGroupCollection &collection, idx_t removed_column); @@ -112,12 +128,12 @@ class RowGroup : public SegmentBase { void CommitDrop(); void CommitDropColumn(const idx_t column_index); - void InitializeEmpty(const vector &types); + void InitializeEmpty(const vector &types, ColumnDataType data_type); bool HasChanges() const; //! Initialize a scan over this row_group - bool InitializeScan(CollectionScanState &state); - bool InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset); + bool InitializeScan(CollectionScanState &state, SegmentNode &node); + bool InitializeScanWithOffset(CollectionScanState &state, SegmentNode &node, idx_t vector_offset); //! Checks the given set of table filters against the row-group statistics. Returns false if the entire row group //! can be skipped. bool CheckZonemap(ScanFilterInfo &filters); @@ -127,6 +143,8 @@ class RowGroup : public SegmentBase { void Scan(TransactionData transaction, CollectionScanState &state, DataChunk &result); void ScanCommitted(CollectionScanState &state, DataChunk &result, TableScanType type); + //! Whether or not this RowGroup should be + bool ShouldCheckpointRowGroup(transaction_t checkpoint_id) const; idx_t GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); idx_t GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); @@ -142,40 +160,43 @@ class RowGroup : public SegmentBase { //! Commit a previous append made by RowGroup::AppendVersionInfo void CommitAppend(transaction_t commit_id, idx_t start, idx_t count); //! Revert a previous append made by RowGroup::AppendVersionInfo - void RevertAppend(idx_t start); + void RevertAppend(idx_t new_count); //! Clean up append states that can either be compressed or deleted void CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_t count); //! Delete the given set of rows in the version manager - idx_t Delete(TransactionData transaction, DataTable &table, row_t *row_ids, idx_t count); + idx_t Delete(TransactionData transaction, DataTable &table, row_t *row_ids, idx_t count, idx_t row_group_start); static vector WriteToDisk(RowGroupWriteInfo &info, - const vector> &row_groups); - RowGroupWriteData WriteToDisk(RowGroupWriteInfo &info); + const vector> &row_groups); + //! Write the data inside this RowGroup to disk and return a struct with information about the write + //! Including the new RowGroup that contains the columns in their written-to-disk form + RowGroupWriteData WriteToDisk(RowGroupWriteInfo &info) const; //! 
Returns the number of committed rows (count - committed deletes) idx_t GetCommittedRowCount(); RowGroupWriteData WriteToDisk(RowGroupWriter &writer); - RowGroupPointer Checkpoint(RowGroupWriteData write_data, RowGroupWriter &writer, TableStatistics &global_stats); + RowGroupPointer Checkpoint(RowGroupWriteData write_data, RowGroupWriter &writer, TableStatistics &global_stats, + idx_t row_group_start); bool IsPersistent() const; - PersistentRowGroupData SerializeRowGroupInfo() const; + PersistentRowGroupData SerializeRowGroupInfo(idx_t row_group_start) const; void InitializeAppend(RowGroupAppendState &append_state); void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); - void Update(TransactionData transaction, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids); + void Update(TransactionData transaction, DataTable &data_table, DataChunk &updates, row_t *ids, idx_t offset, + idx_t count, const vector &column_ids, idx_t row_group_start); //! Update a single column; corresponds to DataTable::UpdateColumn //! This method should only be called from the WAL - void UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path); + void UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path, idx_t row_group_start); void MergeStatistics(idx_t column_idx, const BaseStatistics &other); void MergeIntoStatistics(idx_t column_idx, BaseStatistics &other); void MergeIntoStatistics(TableStatistics &other); - unique_ptr GetStatistics(idx_t column_idx); + unique_ptr GetStatistics(idx_t column_idx) const; - void GetColumnSegmentInfo(idx_t row_group_index, vector &result); - PartitionStatistics GetPartitionStats() const; + void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector &result); + PartitionStatistics GetPartitionStats(idx_t row_group_start); idx_t GetAllocationSize() const { return allocation_size; @@ -195,38 +216,43 @@ class RowGroup : public SegmentBase { idx_t GetRowGroupSize() const; static FilterPropagateResult CheckRowIdFilter(const TableFilter &filter, idx_t beg_row, idx_t end_row); + idx_t GetColumnCount() const; + + vector CheckpointDeletes(MetadataManager &manager); private: optional_ptr GetVersionInfo(); + optional_ptr GetVersionInfoIfLoaded() const; shared_ptr GetOrCreateVersionInfoPtr(); shared_ptr GetOrCreateVersionInfoInternal(); void SetVersionInfo(shared_ptr version); - ColumnData &GetColumn(storage_t c); - ColumnData &GetColumn(const StorageIndex &c); - idx_t GetColumnCount() const; + ColumnData &GetColumn(storage_t c) const; + void LoadColumn(storage_t c) const; + ColumnData &GetColumn(const StorageIndex &c) const; vector> &GetColumns(); - ColumnData &GetRowIdColumnData(); + void LoadRowIdColumnData() const; void SetCount(idx_t count); template void TemplatedScan(TransactionData transaction, CollectionScanState &state, DataChunk &result); - vector CheckpointDeletes(MetadataManager &manager); - bool HasUnloadedDeletes() const; private: - mutex row_group_lock; + mutable mutex row_group_lock; vector column_pointers; - unique_ptr[]> is_loaded; + //! 
Whether or not each column is loaded (mutable because `const` can lazy load) + mutable unique_ptr[]> is_loaded; vector deletes_pointers; bool has_metadata_blocks = false; vector extra_metadata_blocks; atomic deletes_is_loaded; atomic allocation_size; - unique_ptr row_id_column_data; - atomic row_id_is_loaded; + //! The row id column data (mutable because `const` can lazy load) + mutable unique_ptr row_id_column_data; + //! Whether or not `row_id_column_data` is loaded (mutable because `const` can lazy load) + mutable atomic row_id_is_loaded; atomic has_changes; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index 32808ff4c..7437ea415 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -13,6 +13,7 @@ #include "duckdb/storage/statistics/column_statistics.hpp" #include "duckdb/storage/table/table_statistics.hpp" #include "duckdb/storage/storage_index.hpp" +#include "duckdb/common/enums/index_removal_type.hpp" namespace duckdb { @@ -36,6 +37,7 @@ struct CollectionCheckpointState; struct PersistentCollectionData; class CheckpointTask; class TableIOManager; +class DataTable; class RowGroupCollection { public: @@ -58,16 +60,19 @@ class RowGroupCollection { void AppendRowGroup(SegmentLock &l, idx_t start_row); //! Get the nth row-group, negative numbers start from the back (so -1 is the last row group, etc) optional_ptr GetRowGroup(int64_t index); + //! Overrides a row group - should only be used if you know what you're doing (will likely be removed in the future) + void SetRowGroup(int64_t index, shared_ptr new_row_group); void Verify(); void Destroy(); - void InitializeScan(CollectionScanState &state, const vector &column_ids, + void InitializeScan(const QueryContext &context, CollectionScanState &state, const vector &column_ids, optional_ptr table_filters); void InitializeCreateIndexScan(CreateIndexScanState &state); - void InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, idx_t start_row, - idx_t end_row); - static bool InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row); + void InitializeScanWithOffset(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, idx_t end_row); + static bool InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, SegmentNode &row_group, + idx_t vector_index, idx_t max_row); void InitializeParallelScan(ParallelCollectionScanState &state); bool NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, CollectionScanState &scan_state); @@ -97,17 +102,18 @@ class RowGroupCollection { optional_ptr commit_state); bool IsPersistent() const; - void RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, Vector &row_identifiers, idx_t count, + IndexRemovalType removal_type); idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); - void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); - void UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates); + void Update(TransactionData transaction, 
DataTable &table, row_t *ids, const vector &column_ids, + DataChunk &updates); + void UpdateColumn(TransactionData transaction, DataTable &table, Vector &row_ids, + const vector &column_path, DataChunk &updates); void Checkpoint(TableDataWriter &writer, TableStatistics &global_stats); - void InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state, - vector> &segments); + void InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state); bool ScheduleVacuumTasks(CollectionCheckpointState &checkpoint_state, VacuumState &state, idx_t segment_idx, bool schedule_vacuum); unique_ptr GetCheckpointTask(CollectionCheckpointState &checkpoint_state, idx_t segment_idx); @@ -116,7 +122,7 @@ class RowGroupCollection { void CommitDropTable(); vector GetPartitionStats() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); const vector &GetTypes() const; shared_ptr AddColumn(ClientContext &context, ColumnDefinition &new_column, @@ -124,8 +130,9 @@ class RowGroupCollection { shared_ptr RemoveColumn(idx_t col_idx); shared_ptr AlterType(ClientContext &context, idx_t changed_idx, const LogicalType &target_type, vector bound_columns, Expression &cast_expr); - void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + void VerifyNewConstraint(const QueryContext &context, DataTable &parent, const BoundConstraint &constraint); + void SetStats(TableStatistics &new_stats); void CopyStats(TableStatistics &stats); unique_ptr CopyStats(column_t column_id); unique_ptr GetSample(); @@ -150,9 +157,11 @@ class RowGroupCollection { void SetAppendRequiresNewRowGroup(); private: - bool IsEmpty(SegmentLock &) const; + optional_ptr> NextUpdateRowGroup(RowGroupSegmentTree &row_groups, row_t *ids, idx_t &pos, + idx_t count) const; - optional_ptr NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const; + shared_ptr GetRowGroups() const; + void SetRowGroups(shared_ptr row_groups); private: //! BlockManager @@ -165,9 +174,10 @@ class RowGroupCollection { shared_ptr info; //! The column types of the row group collection vector types; - idx_t row_start; - //! The segment trees holding the various row_groups of the table - shared_ptr row_groups; + //! Lock held when accessing or modifying the owned_row_groups pointer + mutable mutex row_group_pointer_lock; + //! The owning pointer of the segment tree + shared_ptr owned_row_groups; //! Table statistics TableStatistics stats; //! 
Allocation size, only tracked for appends diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp new file mode 100644 index 000000000..c41b81cdb --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/row_group_reorderer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/partition_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/table/row_group.hpp" +#include "duckdb/storage/table/row_group_segment_tree.hpp" +#include "duckdb/storage/table/segment_tree.hpp" + +namespace duckdb { + +enum class OrderByStatistics { MIN, MAX }; +enum class RowGroupOrderType { ASC, DESC }; +enum class OrderByColumnType { NUMERIC, STRING }; + +struct RowGroupOrderOptions { + RowGroupOrderOptions(column_t column_idx_p, OrderByStatistics order_by_p, RowGroupOrderType order_type_p, + OrderByColumnType column_type_p, optional_idx row_limit_p = optional_idx(), + idx_t row_group_offset_p = 0) + : column_idx(column_idx_p), order_by(order_by_p), order_type(order_type_p), column_type(column_type_p), + row_limit(row_limit_p), row_group_offset(row_group_offset_p) { + } + + const column_t column_idx; + const OrderByStatistics order_by; + const RowGroupOrderType order_type; + const OrderByColumnType column_type; + const optional_idx row_limit; + const idx_t row_group_offset; +}; + +struct OffsetPruningResult { + idx_t offset_remainder; + idx_t pruned_row_group_count; +}; + +class RowGroupReorderer { +public: + explicit RowGroupReorderer(const RowGroupOrderOptions &options_p); + optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups); + optional_ptr> GetNextRowGroup(SegmentNode &row_group); + + static Value RetrieveStat(const BaseStatistics &stats, OrderByStatistics order_by, OrderByColumnType column_type); + static OffsetPruningResult GetOffsetAfterPruning(OrderByStatistics order_by, OrderByColumnType column_type, + RowGroupOrderType order_type, column_t column_idx, + idx_t row_offset, vector &stats); + +private: + const RowGroupOrderOptions options; + + idx_t offset; + bool initialized; + vector>> ordered_row_groups; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp index d93610138..5c2a91845 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp @@ -18,7 +18,7 @@ class MetadataReader; class RowGroupSegmentTree : public SegmentTree { public: - explicit RowGroupSegmentTree(RowGroupCollection &collection); + RowGroupSegmentTree(RowGroupCollection &collection, idx_t base_row_id); ~RowGroupSegmentTree() override; void Initialize(PersistentTableData &data); @@ -28,12 +28,12 @@ class RowGroupSegmentTree : public SegmentTree { } protected: - unique_ptr LoadSegment() override; + shared_ptr LoadSegment() const override; RowGroupCollection &collection; - idx_t current_row_group; - idx_t max_row_group; - unique_ptr reader; + mutable idx_t current_row_group; + mutable idx_t max_row_group; + mutable unique_ptr reader; MetaBlockPointer root_pointer; }; diff --git 
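The new row_group_reorderer.hpp above orders row groups by a per-column min or max statistic and can prune whole row groups against a row offset before scanning (GetOffsetAfterPruning returns the remaining offset plus how many leading groups were skipped). The sketch below shows only the arithmetic of that pruning step on a plain vector of (row count, statistic) pairs; the names and types are illustrative, not the reorderer's actual interface.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// One (row count, min-statistic) entry per row group - an illustrative stand-in
// for the statistics the reorderer reads from each row group.
struct GroupStatSketch {
	uint64_t row_count;
	int64_t min_value;
};

struct PruneResultSketch {
	uint64_t offset_remainder;   // offset left over inside the first surviving group
	uint64_t pruned_group_count; // number of leading groups skipped entirely
};

// Order groups ascending by their min statistic, then consume the row offset
// group by group: any group fully covered by the offset can be skipped.
static PruneResultSketch PruneByOffset(std::vector<GroupStatSketch> groups, uint64_t row_offset) {
	std::sort(groups.begin(), groups.end(),
	          [](const GroupStatSketch &a, const GroupStatSketch &b) { return a.min_value < b.min_value; });
	PruneResultSketch result {row_offset, 0};
	for (const auto &group : groups) {
		if (result.offset_remainder < group.row_count) {
			break; // the offset lands inside this group - stop pruning
		}
		result.offset_remainder -= group.row_count;
		result.pruned_group_count++;
	}
	return result;
}

int main() {
	// Three groups of 1000 rows each; an offset of 2300 skips the first two.
	std::vector<GroupStatSketch> groups {{1000, 5}, {1000, 1}, {1000, 9}};
	auto pruned = PruneByOffset(groups, 2300);
	assert(pruned.pruned_group_count == 2 && pruned.offset_remainder == 300);
	return 0;
}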
a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp index 3bc6572ae..42ab24a8c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp @@ -14,7 +14,7 @@ namespace duckdb { class RowIdColumnData : public ColumnData { public: - RowIdColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t start_row); + RowIdColumnData(BlockManager &block_manager, DataTableInfo &info); public: void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; @@ -46,21 +46,22 @@ class RowIdColumnData : public ColumnData { void InitializeAppend(ColumnAppendState &state) override; void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; - void RevertAppend(row_t start_row) override; + void RevertAppend(row_t new_count) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; - void CommitDropColumn() override; + void VisitBlockIds(BlockIdVisitor &visitor) const override; - unique_ptr CreateCheckpointState(RowGroup &row_group, + unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; - unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; - void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector) override; + void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, + Vector &scan_vector) const override; bool IsPersistent() override; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp index bb0d0056b..8856ce57b 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp @@ -12,23 +12,25 @@ #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/common/mutex.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { struct DeleteInfo; class MetadataManager; +class BufferManager; struct MetaBlockPointer; class RowVersionManager { public: - explicit RowVersionManager(idx_t start) noexcept; + explicit RowVersionManager(BufferManager &buffer_manager) noexcept; - idx_t GetStart() { - return start; + FixedSizeAllocator &GetAllocator() { + return allocator; } - void SetStart(idx_t start); idx_t GetCommittedDeletedCount(idx_t count); + bool ShouldCheckpointRowGroup(transaction_t 
checkpoint_id, idx_t count); idx_t GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); idx_t GetCommittedSelVector(transaction_t start_time, transaction_t transaction_id, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); @@ -36,21 +38,23 @@ class RowVersionManager { void AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, idx_t row_group_end); void CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count); - void RevertAppend(idx_t start_row); + void RevertAppend(idx_t new_count); void CleanupAppend(transaction_t lowest_active_transaction, idx_t row_group_start, idx_t count); idx_t DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count); void CommitDelete(idx_t vector_idx, transaction_t commit_id, const DeleteInfo &info); vector Checkpoint(MetadataManager &manager); - static shared_ptr Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, - idx_t start); + static shared_ptr Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager); + + bool HasUnserializedChanges(); + vector GetStoragePointers(); private: mutex version_lock; - idx_t start; + FixedSizeAllocator allocator; vector> vector_info; - bool has_changes; + bool has_unserialized_changes; vector storage_pointers; private: diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp index 97416cbf2..3b17cfb68 100644 --- a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/map.hpp" #include "duckdb/storage/buffer/buffer_handle.hpp" #include "duckdb/storage/storage_lock.hpp" +#include "duckdb/storage/table/row_group_reorderer.hpp" #include "duckdb/common/enums/scan_options.hpp" #include "duckdb/common/random_engine.hpp" #include "duckdb/storage/table/segment_lock.hpp" @@ -42,6 +43,8 @@ struct AdaptiveFilterState; struct TableScanOptions; struct ScanSamplingInfo; struct TableFilterState; +template +struct SegmentNode; struct SegmentScanState { virtual ~SegmentScanState() { @@ -78,18 +81,24 @@ struct IndexScanState { typedef unordered_map buffer_handle_set_t; struct ColumnScanState { + explicit ColumnScanState(optional_ptr parent_p) : parent(parent_p) { + } + + optional_ptr parent; + //! The query context for this scan + QueryContext context; //! The column segment that is currently being scanned - ColumnSegment *current = nullptr; + optional_ptr> current; //! Column segment tree ColumnSegmentTree *segment_tree = nullptr; - //! The current row index of the scan - idx_t row_index = 0; + //! The current row offset in the column + idx_t offset_in_column = 0; //! The internal row index (i.e. the position of the SegmentScanState) idx_t internal_index = 0; //! Segment scan state unique_ptr scan_state; //! Child states of the vector - vector child_states; + unsafe_vector child_states; //! Whether or not InitializeState has been called for this segment bool initialized = false; //! 
If this segment has already been checked for skipping purposes @@ -105,20 +114,26 @@ struct ColumnScanState { optional_ptr scan_options; public: - void Initialize(const LogicalType &type, const vector &children, + void Initialize(const QueryContext &context_p, const LogicalType &type, const vector &children, optional_ptr options); - void Initialize(const LogicalType &type, optional_ptr options); + void Initialize(const QueryContext &context_p, const LogicalType &type, optional_ptr options); //! Move the scan state forward by "count" rows (including all child states) void Next(idx_t count); //! Move ONLY this state forward by "count" rows (i.e. not the child states) void NextInternal(idx_t count); + //! Returns the current row position in the segment + idx_t GetPositionInSegment() const; }; struct ColumnFetchState { + //! The query context for this fetch + QueryContext context; //! The set of pinned block handles for this set of fetches buffer_handle_set_t handles; //! Any child states of the fetch vector> child_states; + //! The current row group we are fetching from + optional_ptr> row_group; BufferHandle &GetOrInsertHandle(ColumnSegment &segment); }; @@ -183,15 +198,15 @@ class CollectionScanState { explicit CollectionScanState(TableScanState &parent_p); //! The current row_group we are scanning - RowGroup *row_group; + optional_ptr> row_group; //! The vector index within the row_group idx_t vector_index; //! The maximum row within the row group idx_t max_row_group_row; //! Child column scans - unsafe_unique_array column_scans; - //! Row group segment tree - RowGroupSegmentTree *row_groups; + unsafe_vector column_scans; + //! Row group segment tree we are scanning + shared_ptr row_groups; //! The total maximum row index idx_t max_row; //! The current batch index @@ -201,12 +216,18 @@ class CollectionScanState { RandomEngine random; + //! Optional state for custom row group ordering + unique_ptr reorderer; + public: - void Initialize(const vector &types); + void Initialize(const QueryContext &context, const vector &types); const vector &GetColumnIds(); ScanFilterInfo &GetFilterInfo(); ScanSamplingInfo &GetSamplingInfo(); TableScanOptions &GetOptions(); + optional_ptr> GetNextRowGroup(SegmentNode &row_group) const; + optional_ptr> GetNextRowGroup(SegmentLock &l, SegmentNode &row_group) const; + optional_ptr> GetRootSegment() const; bool Scan(DuckTransaction &transaction, DataChunk &result); bool ScanCommitted(DataChunk &result, TableScanType type); bool ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type); @@ -272,15 +293,22 @@ class TableScanState { struct ParallelCollectionScanState { ParallelCollectionScanState(); + optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups) const; + optional_ptr> GetNextRowGroup(RowGroupSegmentTree &row_groups, + SegmentNode &row_group) const; //! The row group collection we are scanning RowGroupCollection *collection; - RowGroup *current_row_group; + shared_ptr row_groups; + optional_ptr> current_row_group; idx_t vector_index; idx_t max_row; idx_t batch_index; atomic processed_rows; mutex lock; + + //! 
Optional state for custom row group ordering + unique_ptr reorderer; }; struct ParallelTableScanState { @@ -302,6 +330,7 @@ struct PrefetchState { class CreateIndexScanState : public TableScanState { public: + shared_ptr row_groups; vector> locks; unique_lock append_lock; SegmentLock segment_lock; diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp index b71587bf8..1a45c26e9 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_base.hpp @@ -16,28 +16,11 @@ namespace duckdb { template class SegmentBase { public: - SegmentBase(idx_t start, idx_t count) : start(start), count(count), next(nullptr) { - } - T *Next() { -#ifndef DUCKDB_R_BUILD - return next.load(); -#else - return next; -#endif + explicit SegmentBase(idx_t count) : count(count) { } - //! The start row id of this chunk - idx_t start; //! The amount of entries in this storage chunk atomic count; - //! The next segment after this one -#ifndef DUCKDB_R_BUILD - atomic next; -#else - T *next; -#endif - //! The index within the segment tree - idx_t index; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp index f427a5275..8633f6f97 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp @@ -19,19 +19,82 @@ namespace duckdb { template struct SegmentNode { + SegmentNode(idx_t row_start_p, shared_ptr node_p, idx_t index_p) + : row_start(row_start_p), node(std::move(node_p)), next(nullptr), index(index_p) { + } + +public: + optional_ptr> Next() const { +#ifndef DUCKDB_R_BUILD + return next.load(); +#else + return next; +#endif + } + + idx_t GetRowStart() const { + return row_start; + } + idx_t GetRowEnd() const { + return GetRowStart() + GetCount(); + } + idx_t GetCount() const { + return GetNode().count; + } + + idx_t GetIndex() const { + return index; + } + + T &GetNode() const { + return *node; + } + + shared_ptr MoveNode() { + return std::move(node); + } + shared_ptr &ReferenceNode() { + return node; + } + + bool HasNode() const { + return node.get(); + } + + void SetNext(optional_ptr> next) { + this->next = next.get(); + } + + void SetNode(shared_ptr new_node) { + node = std::move(new_node); + } + +private: idx_t row_start; - unique_ptr node; + shared_ptr node; + //! The next segment after this one +#ifndef DUCKDB_R_BUILD + atomic *> next; +#else + SegmentNode *next; +#endif + //! The index within the segment tree + idx_t index; }; //! The SegmentTree maintains a list of all segments of a specific column in a table, and allows searching for a segment //! by row number +// The const-ness of the SegmentTree is implemented in an odd manner due to the lazy loading +// in particular, most internal members are `mutable` - i.e. 
they can be internally modified even through `const` +// methods The reasoning this is implemented this way is that the lazy loading would otherwise template class SegmentTree { private: class SegmentIterationHelper; + class SegmentNodeIterationHelper; public: - explicit SegmentTree() : finished_loading(true) { + explicit SegmentTree(idx_t base_row_id = 0) : finished_loading(true), base_row_id(base_row_id) { } virtual ~SegmentTree() { } @@ -42,60 +105,50 @@ class SegmentTree { return SegmentLock(node_lock); } - bool IsEmpty(SegmentLock &l) { + bool IsEmpty(SegmentLock &l) const { return GetRootSegment(l) == nullptr; } //! Gets a pointer to the first segment. Useful for scans. - T *GetRootSegment() { + optional_ptr> GetRootSegment() const { auto l = Lock(); return GetRootSegment(l); } - T *GetRootSegment(SegmentLock &l) { + optional_ptr> GetRootSegment(SegmentLock &l) const { if (nodes.empty()) { LoadNextSegment(l); } return GetRootSegmentInternal(); } //! Obtains ownership of the data of the segment tree - vector> MoveSegments(SegmentLock &l) { + vector>> MoveSegments(SegmentLock &l) { LoadAllSegments(l); return std::move(nodes); } - vector> MoveSegments() { + vector>> MoveSegments() { auto l = Lock(); return MoveSegments(l); } - const vector> &ReferenceSegments(SegmentLock &l) { - LoadAllSegments(l); - return nodes; - } - const vector> &ReferenceSegments() { - auto l = Lock(); - return ReferenceSegments(l); - } - vector> &ReferenceLoadedSegmentsMutable(SegmentLock &l) { - return nodes; - } - const vector> &ReferenceLoadedSegments(SegmentLock &l) const { + vector>> &ReferenceLoadedSegmentsMutable(SegmentLock &l) { return nodes; } - idx_t GetSegmentCount() { + idx_t GetSegmentCount() const { auto l = Lock(); return GetSegmentCount(l); } idx_t GetSegmentCount(SegmentLock &l) const { + LoadAllSegments(l); return nodes.size(); } //! Gets a pointer to the nth segment. Negative numbers start from the back. - T *GetSegmentByIndex(int64_t index) { + optional_ptr> GetSegmentByIndex(int64_t index) const { auto l = Lock(); return GetSegmentByIndex(l, index); } - T *GetSegmentByIndex(SegmentLock &l, int64_t index) { + optional_ptr> GetSegmentByIndex(SegmentLock &l, int64_t index) const { if (index < 0) { // load all segments LoadAllSegments(l); @@ -103,7 +156,7 @@ class SegmentTree { if (index < 0) { return nullptr; } - return nodes[UnsafeNumericCast(index)].node.get(); + return nodes[UnsafeNumericCast(index)].get(); } else { // lazily load segments until we reach the specific segment while (idx_t(index) >= nodes.size() && LoadNextSegment(l)) { @@ -111,76 +164,64 @@ class SegmentTree { if (idx_t(index) >= nodes.size()) { return nullptr; } - return nodes[UnsafeNumericCast(index)].node.get(); + return nodes[UnsafeNumericCast(index)].get(); } } //! 
Gets the next segment - T *GetNextSegment(T *segment) { + optional_ptr> GetNextSegment(SegmentNode &node) const { if (!SUPPORTS_LAZY_LOADING) { - return segment->Next(); + return node.Next(); } if (finished_loading) { - return segment->Next(); + return node.Next(); } auto l = Lock(); - return GetNextSegment(l, segment); + return GetNextSegment(l, node); } - T *GetNextSegment(SegmentLock &l, T *segment) { - if (!segment) { - return nullptr; - } + optional_ptr> GetNextSegment(SegmentLock &l, SegmentNode &node) const { #ifdef DEBUG - D_ASSERT(nodes[segment->index].node.get() == segment); + D_ASSERT(RefersToSameObject(*nodes[node.GetIndex()], node)); #endif - return GetSegmentByIndex(l, UnsafeNumericCast(segment->index + 1)); + return GetSegmentByIndex(l, UnsafeNumericCast(node.GetIndex() + 1)); } //! Gets a pointer to the last segment. Useful for appends. - T *GetLastSegment(SegmentLock &l) { + optional_ptr> GetLastSegment(SegmentLock &l) const { LoadAllSegments(l); if (nodes.empty()) { return nullptr; } - return nodes.back().node.get(); + return nodes.back().get(); } //! Gets a pointer to a specific column segment for the given row - T *GetSegment(idx_t row_number) { + optional_ptr> GetSegment(idx_t row_number) const { auto l = Lock(); return GetSegment(l, row_number); } - T *GetSegment(SegmentLock &l, idx_t row_number) { - return nodes[GetSegmentIndex(l, row_number)].node.get(); + optional_ptr> GetSegment(SegmentLock &l, idx_t row_number) const { + return nodes[GetSegmentIndex(l, row_number)].get(); } - //! Append a column segment to the tree - void AppendSegmentInternal(SegmentLock &l, unique_ptr segment) { - D_ASSERT(segment); - // add the node to the list of nodes - if (!nodes.empty()) { - nodes.back().node->next = segment.get(); - } - SegmentNode node; - segment->index = nodes.size(); - segment->next = nullptr; - node.row_start = segment->start; - node.node = std::move(segment); - nodes.push_back(std::move(node)); - } - void AppendSegment(unique_ptr segment) { + void AppendSegment(shared_ptr segment) { auto l = Lock(); AppendSegment(l, std::move(segment)); } - void AppendSegment(SegmentLock &l, unique_ptr segment) { + void AppendSegment(SegmentLock &l, shared_ptr segment) { LoadAllSegments(l); AppendSegmentInternal(l, std::move(segment)); } + void AppendSegment(SegmentLock &l, shared_ptr segment, idx_t row_start) { + LoadAllSegments(l); + AppendSegmentInternal(l, std::move(segment), row_start); + } //! Debug method, check whether the segment is in the segment tree - bool HasSegment(T *segment) { + bool HasSegment(SegmentNode &segment) const { auto l = Lock(); return HasSegment(l, segment); } - bool HasSegment(SegmentLock &, T *segment) { - return segment->index < nodes.size() && nodes[segment->index].node.get() == segment; + bool HasSegment(SegmentLock &, SegmentNode &segment) const { + auto segment_idx = segment.GetIndex(); + return segment_idx < nodes.size() && RefersToSameObject(*nodes[segment_idx], segment); } //! Erase all segments after a specific segment @@ -193,7 +234,7 @@ class SegmentTree { } //! 
Get the segment index of the column segment for the given row - idx_t GetSegmentIndex(SegmentLock &l, idx_t row_number) { + idx_t GetSegmentIndex(SegmentLock &l, idx_t row_number) const { idx_t segment_index; if (TryGetSegmentIndex(l, row_number, segment_index)) { return segment_index; @@ -201,15 +242,15 @@ class SegmentTree { string error; error = StringUtil::Format("Attempting to find row number \"%lld\" in %lld nodes\n", row_number, nodes.size()); for (idx_t i = 0; i < nodes.size(); i++) { - error += StringUtil::Format("Node %lld: Start %lld, Count %lld", i, nodes[i].row_start, - nodes[i].node->count.load()); + error += StringUtil::Format("Node %lld: Start %lld, Count %lld", i, nodes[i]->GetRowStart(), + nodes[i]->GetCount()); } throw InternalException("Could not find node in column segment tree!\n%s", error); } - bool TryGetSegmentIndex(SegmentLock &l, idx_t row_number, idx_t &result) { + bool TryGetSegmentIndex(SegmentLock &l, idx_t row_number, idx_t &result) const { // load segments until the row number is within bounds - while (nodes.empty() || (row_number >= (nodes.back().row_start + nodes.back().node->count))) { + while (nodes.empty() || (row_number >= nodes.back()->GetRowEnd())) { if (!LoadNextSegment(l)) { break; } @@ -225,16 +266,15 @@ class SegmentTree { if (index >= nodes.size()) { string segments; for (auto &entry : nodes) { - segments += StringUtil::Format("Start %d Count %d", entry.row_start, entry.node->count.load()); + segments += StringUtil::Format("Start %d Count %d", entry->GetRowStart(), entry->GetCount()); } throw InternalException("Segment tree index not found for row number %d\nSegments:%s", row_number, segments); } - auto &entry = nodes[index]; - D_ASSERT(entry.row_start == entry.node->start); - if (row_number < entry.row_start) { + auto &entry = *nodes[index]; + if (row_number < entry.GetRowStart()) { upper = index - 1; - } else if (row_number >= entry.row_start + entry.node->count) { + } else if (row_number >= entry.GetRowEnd()) { lower = index + 1; } else { result = index; @@ -244,13 +284,12 @@ class SegmentTree { return false; } - void Verify(SegmentLock &) { + void Verify(SegmentLock &) const { #ifdef DEBUG - idx_t base_start = nodes.empty() ? 0 : nodes[0].node->start; + idx_t base_start = nodes.empty() ? 
0 : nodes[0]->GetRowStart(); for (idx_t i = 0; i < nodes.size(); i++) { - D_ASSERT(nodes[i].row_start == nodes[i].node->start); - D_ASSERT(nodes[i].node->start == base_start); - base_start += nodes[i].node->count; + D_ASSERT(nodes[i]->GetRowStart() == base_start); + base_start += nodes[i]->GetCount(); } #endif } @@ -261,84 +300,125 @@ class SegmentTree { #endif } - SegmentIterationHelper Segments() { + idx_t GetBaseRowId() const { + return base_row_id; + } + + SegmentIterationHelper Segments() const { return SegmentIterationHelper(*this); } - SegmentIterationHelper Segments(SegmentLock &l) { + SegmentIterationHelper Segments(SegmentLock &l) const { return SegmentIterationHelper(*this, l); } - void Reinitialize() { - if (nodes.empty()) { - return; - } - idx_t offset = nodes[0].node->start; - for (auto &entry : nodes) { - if (entry.node->start != offset) { - throw InternalException("In SegmentTree::Reinitialize - gap found between nodes!"); - } - entry.row_start = offset; - offset += entry.node->count; - } + SegmentNodeIterationHelper SegmentNodes() const { + return SegmentNodeIterationHelper(*this); + } + + SegmentNodeIterationHelper SegmentNodes(SegmentLock &l) const { + return SegmentNodeIterationHelper(*this, l); } protected: - atomic finished_loading; + mutable atomic finished_loading; //! Load the next segment - only used when lazily loading - virtual unique_ptr LoadSegment() { + virtual shared_ptr LoadSegment() const { return nullptr; } - T *GetRootSegmentInternal() const { - return nodes.empty() ? nullptr : nodes[0].node.get(); + optional_ptr> GetRootSegmentInternal() const { + return nodes.empty() ? nullptr : nodes[0].get(); } private: //! The nodes in the tree, can be binary searched - vector> nodes; + mutable vector>> nodes; //! Lock to access or modify the nodes mutable mutex node_lock; + //! Base row id (row id of the first segment) + idx_t base_row_id; private: + class BaseSegmentIterator { + public: + BaseSegmentIterator(const SegmentTree &tree_p, optional_ptr> current_p, + optional_ptr lock) + : tree(tree_p), current(current_p), lock(lock) { + } + + const SegmentTree &tree; + optional_ptr> current; + optional_ptr lock; + + public: + void Next() { + current = lock ? tree.GetNextSegment(*lock, *current) : tree.GetNextSegment(*current); + } + + BaseSegmentIterator &operator++() { + Next(); + return *this; + } + bool operator!=(const BaseSegmentIterator &other) const { + return current != other.current; + } + }; class SegmentIterationHelper { public: - explicit SegmentIterationHelper(SegmentTree &tree) : tree(tree) { + explicit SegmentIterationHelper(const SegmentTree &tree) : tree(tree) { } - SegmentIterationHelper(SegmentTree &tree, SegmentLock &l) : tree(tree), lock(l) { + SegmentIterationHelper(const SegmentTree &tree, SegmentLock &l) : tree(tree), lock(l) { } private: - SegmentTree &tree; + const SegmentTree &tree; optional_ptr lock; private: - class SegmentIterator { + class SegmentIterator : public BaseSegmentIterator { public: - SegmentIterator(SegmentTree &tree_p, T *current_p, optional_ptr lock) - : tree(tree_p), current(current_p), lock(lock) { + SegmentIterator(const SegmentTree &tree_p, optional_ptr> current_p, + optional_ptr lock) + : BaseSegmentIterator(tree_p, current_p, lock) { } - SegmentTree &tree; - T *current; - optional_ptr lock; + T &operator*() const { + return BaseSegmentIterator::current->GetNode(); + } + }; + public: + SegmentIterator begin() { // NOLINT: match stl API + auto root = lock ? 
tree.GetRootSegment(*lock) : tree.GetRootSegment(); + return SegmentIterator(tree, root, lock); + } + SegmentIterator end() { // NOLINT: match stl API + return SegmentIterator(tree, nullptr, lock); + } + }; + class SegmentNodeIterationHelper { + public: + explicit SegmentNodeIterationHelper(const SegmentTree &tree) : tree(tree) { + } + SegmentNodeIterationHelper(const SegmentTree &tree, SegmentLock &l) : tree(tree), lock(l) { + } + + private: + const SegmentTree &tree; + optional_ptr lock; + + private: + class SegmentIterator : public BaseSegmentIterator { public: - void Next() { - current = lock ? tree.GetNextSegment(*lock, current) : tree.GetNextSegment(current); + SegmentIterator(const SegmentTree &tree_p, optional_ptr> current_p, + optional_ptr lock) + : BaseSegmentIterator(tree_p, current_p, lock) { } - SegmentIterator &operator++() { - Next(); - return *this; - } - bool operator!=(const SegmentIterator &other) const { - return current != other.current; - } - T &operator*() const { - D_ASSERT(current); - return *current; + SegmentNode &operator*() { + return *BaseSegmentIterator::current; } }; @@ -353,7 +433,7 @@ class SegmentTree { }; //! Load the next segment, if there are any left to load - bool LoadNextSegment(SegmentLock &l) { + bool LoadNextSegment(SegmentLock &l) const { if (!SUPPORTS_LAZY_LOADING) { return false; } @@ -369,13 +449,34 @@ class SegmentTree { } //! Load all segments, if there are any left to load - void LoadAllSegments(SegmentLock &l) { + void LoadAllSegments(SegmentLock &l) const { if (!SUPPORTS_LAZY_LOADING) { return; } while (LoadNextSegment(l)) { } } + + //! Append a column segment to the tree + void AppendSegmentInternal(SegmentLock &l, shared_ptr segment, idx_t row_start) const { + D_ASSERT(segment); + // add the node to the list of nodes + auto node = make_uniq>(row_start, std::move(segment), nodes.size()); + if (!nodes.empty()) { + nodes.back()->SetNext(*node); + } + nodes.push_back(std::move(node)); + } + void AppendSegmentInternal(SegmentLock &l, shared_ptr segment) const { + idx_t row_start; + if (nodes.empty()) { + row_start = base_row_id; + } else { + auto &last_node = nodes.back(); + row_start = last_node->GetRowEnd(); + } + AppendSegmentInternal(l, std::move(segment), row_start); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 48ac6ccb7..9b8a9509a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -16,14 +16,11 @@ namespace duckdb { //! Standard column data represents a regular flat column (e.g. a column of type INTEGER or STRING) class StandardColumnData : public ColumnData { public: - StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type, optional_ptr parent = nullptr); - - //! 
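The `segment_tree.hpp` rework above moves ownership, row offsets, the intrusive `next` pointer, and the tree index from the segment itself (see the trimmed `SegmentBase`) into `SegmentNode`, and the tree now stores `shared_ptr`s to its segments. Below is a hedged sketch of the two iteration helpers and the new append path, assuming the payload type is `RowGroup`, that additional headers (e.g. for `RowGroup`) are included, and that `SegmentNodes()` exposes `begin()`/`end()` the same way `Segments()` does.

```cpp
#include "duckdb/storage/table/segment_tree.hpp"

using namespace duckdb;

// Segments() still yields the payload T&; SegmentNodes() yields SegmentNode<T>&,
// which now carries GetRowStart()/GetRowEnd()/GetCount()/GetIndex().
void InspectTree(const SegmentTree<RowGroup> &tree) {
	for (auto &row_group : tree.Segments()) {
		(void)row_group;
	}
	for (auto &node : tree.SegmentNodes()) {
		idx_t start = node.GetRowStart();
		idx_t end = node.GetRowEnd(); // == GetRowStart() + GetCount()
		(void)start;
		(void)end;
	}
}

// Appending hands the tree a shared_ptr; the row start is derived from the previous node,
// or from the tree's base_row_id when the tree is still empty.
void AppendRowGroup(SegmentTree<RowGroup> &tree, shared_ptr<RowGroup> row_group) {
	tree.AppendSegment(std::move(row_group));
}
```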
The validity column data - ValidityColumnData validity; + StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); public: - void SetStart(idx_t new_start) override; + void SetDataType(ColumnDataType data_type) override; ScanVectorType GetVectorScanType(ColumnScanState &state, idx_t scan_count, Vector &result) override; void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; @@ -43,26 +40,27 @@ class StandardColumnData : public ColumnData { void InitializeAppend(ColumnAppendState &state) override; void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; - void RevertAppend(row_t start_row) override; + void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; unique_ptr GetUpdateStatistics() override; - void CommitDropColumn() override; + void VisitBlockIds(BlockIdVisitor &visitor) const override; - unique_ptr CreateCheckpointState(RowGroup &row_group, + unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; - unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; - void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; + void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, + Vector &scan_vector) const override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; bool IsPersistent() override; bool HasAnyChanges() const override; @@ -70,6 +68,12 @@ class StandardColumnData : public ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void Verify(RowGroup &parent) override; + + void SetValidityData(shared_ptr validity); + +protected: + //! The validity column data + shared_ptr validity; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index d05436bfc..64b74fc42 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -16,16 +16,11 @@ namespace duckdb { //! 
Struct column data represents a struct class StructColumnData : public ColumnData { public: - StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type, optional_ptr parent = nullptr); - - //! The sub-columns of the struct - vector> sub_columns; - //! The validity column data of the struct - ValidityColumnData validity; + StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); public: - void SetStart(idx_t new_start) override; + void SetDataType(ColumnDataType data_type) override; idx_t GetMaxEntry() override; void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; @@ -42,31 +37,41 @@ class StructColumnData : public ColumnData { void InitializeAppend(ColumnAppendState &state) override; void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; - void RevertAppend(row_t start_row) override; + void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; unique_ptr GetUpdateStatistics() override; - void CommitDropColumn() override; + void VisitBlockIds(BlockIdVisitor &visitor) const override; - unique_ptr CreateCheckpointState(RowGroup &row_group, + unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; - unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; + + void SetValidityData(shared_ptr validity_p); + void SetChildData(idx_t i, shared_ptr child_column_p); + +protected: + //! The sub-columns of the struct + vector> sub_columns; + //! 
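With the `StandardColumnData` and `StructColumnData` changes above, validity and child column data become `shared_ptr` members attached through explicit setters rather than value members constructed in place. A small hypothetical wiring sketch; the collapsed pointer element types are assumed to be `ValidityColumnData` and `ColumnData`, and `AttachStructChildren` is an illustrative helper, not part of the patch.

```cpp
#include "duckdb/storage/table/struct_column_data.hpp"
#include "duckdb/storage/table/validity_column_data.hpp"

using namespace duckdb;

// Hypothetical helper: attach validity and child column data to an already-constructed
// struct column via the new setters.
void AttachStructChildren(StructColumnData &column, shared_ptr<ValidityColumnData> validity,
                          vector<shared_ptr<ColumnData>> children) {
	column.SetValidityData(std::move(validity));
	for (idx_t i = 0; i < children.size(); i++) {
		column.SetChildData(i, std::move(children[i]));
	}
}
```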
The validity column data of the struct + shared_ptr validity; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp index b2cad5cca..779dbd331 100644 --- a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp @@ -26,8 +26,12 @@ enum class IndexBindState : uint8_t { UNBOUND, BINDING, BOUND }; //! IndexEntry contains an atomic in addition to the index to ensure correct binding. struct IndexEntry { explicit IndexEntry(unique_ptr index); + atomic bind_state; + //! lock that should be used if access to "index" and "deleted_rows_in_use" at the same time is necessary + mutex lock; unique_ptr index; + unique_ptr deleted_rows_in_use; }; class TableIndexList { @@ -43,6 +47,15 @@ class TableIndexList { } } + template + void ScanEntries(T &&callback) { + lock_guard lock(index_entries_lock); + for (auto &entry : index_entries) { + if (callback(*entry)) { + break; + } + } + } //! Adds an index entry to the list of index entries. void AddIndex(unique_ptr index); //! Removes an index entry from the list of index entries. diff --git a/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp index 3bb712f8b..af1e1667d 100644 --- a/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp @@ -32,6 +32,7 @@ class TableStatistics { public: void Initialize(const vector &types, PersistentTableData &data); void InitializeEmpty(const vector &types); + void InitializeEmpty(const TableStatistics &other); void InitializeAddColumn(TableStatistics &parent, const LogicalType &new_column_type); void InitializeRemoveColumn(TableStatistics &parent, idx_t removed_column); @@ -42,6 +43,7 @@ class TableStatistics { void MergeStats(idx_t i, BaseStatistics &stats); void MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats); + void SetStats(TableStatistics &other); void CopyStats(TableStatistics &other); void CopyStats(TableStatisticsLock &lock, TableStatistics &other); unique_ptr CopyStats(idx_t i); diff --git a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp index 75cf25ecf..6cc88457c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp @@ -38,8 +38,8 @@ class UpdateSegment { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result); void FetchCommitted(idx_t vector_index, Vector &result); void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); - void Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, - Vector &base_data); + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update, row_t *ids, + idx_t count, Vector &base_data, idx_t row_group_start); void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); void RollbackUpdate(UpdateInfo &info); @@ -70,7 +70,7 @@ class UpdateSegment { UnifiedVectorFormat &update, const SelectionVector &sel); typedef void (*merge_update_function_t)(UpdateInfo &base_info, Vector &base_data, UpdateInfo &update_info, UnifiedVectorFormat &update, row_t *ids, idx_t count, - const SelectionVector &sel); + const 
SelectionVector &sel, idx_t row_group_start); typedef void (*fetch_update_function_t)(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, Vector &result); typedef void (*fetch_committed_function_t)(UpdateInfo &info, Vector &result); diff --git a/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp index 286a5343b..c2f065953 100644 --- a/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp @@ -17,12 +17,19 @@ class ValidityColumnData : public ColumnData { friend class StandardColumnData; public: - ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - ColumnData &parent); + ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, ColumnData &parent); + ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, ColumnDataType data_type, + optional_ptr parent); public: FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + unique_ptr CreateCheckpointState(const RowGroup &row_group, + PartialBlockManager &partial_block_manager) override; + + void Verify(RowGroup &parent) override; + void UpdateWithBase(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, ColumnData &base, idx_t row_group_start); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp new file mode 100644 index 000000000..d2756dec2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/variant_column_data.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/validity_column_data.hpp" + +namespace duckdb { + +//! Struct column data represents a struct +class VariantColumnData : public ColumnData { +public: + VariantColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type, + ColumnDataType data_type, optional_ptr parent); + + //! 
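The `ScanEntries` template added to `TableIndexList` a little earlier in this patch iterates the index entries under the list's internal lock and stops as soon as the callback returns true. A hypothetical usage sketch; `HasUnboundIndex` is an illustrative name, and `bind_state` is assumed to be an `atomic<IndexBindState>`.

```cpp
#include "duckdb/storage/table/table_index_list.hpp"

using namespace duckdb;

// Scan the index entries under the internal lock; returning true from the callback
// stops the iteration early.
bool HasUnboundIndex(TableIndexList &indexes) {
	bool found = false;
	indexes.ScanEntries([&](IndexEntry &entry) {
		if (entry.bind_state.load() == IndexBindState::UNBOUND) {
			found = true;
			return true; // stop scanning
		}
		return false; // keep scanning
	});
	return found;
}
```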
The sub-columns of the struct + vector> sub_columns; + shared_ptr validity; + +public: + idx_t GetMaxEntry() override; + bool IsShredded() const { + return sub_columns.size() == 2; + } + + void InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) override; + void InitializeScan(ColumnScanState &state) override; + void InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) override; + + Vector CreateUnshreddingIntermediate(idx_t count); + idx_t Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + idx_t scan_count) override; + idx_t ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates, + idx_t scan_count) override; + idx_t ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset = 0) override; + + void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; + + void InitializeAppend(ColumnAppendState &state) override; + void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void RevertAppend(row_t new_count) override; + idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; + void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) override; + unique_ptr GetUpdateStatistics() override; + + void VisitBlockIds(BlockIdVisitor &visitor) const override; + + unique_ptr CreateCheckpointState(const RowGroup &row_group, + PartialBlockManager &partial_block_manager) override; + unique_ptr Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) override; + + bool IsPersistent() override; + bool HasAnyChanges() const override; + PersistentColumnData Serialize() override; + void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; + + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; + + void Verify(RowGroup &parent) override; + + static void ShredVariantData(Vector &input, Vector &output, idx_t count); + static void UnshredVariantData(Vector &input, Vector &output, idx_t count); + + void SetValidityData(shared_ptr validity_p); + void SetChildData(vector> child_data); + +private: + vector> WriteShreddedData(const RowGroup &row_group, const LogicalType &shredded_type, + BaseStatistics &stats); + void CreateScanStates(ColumnScanState &state); + LogicalType GetShreddedType(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp b/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp index 493dd320a..f733c50ec 100644 --- a/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table_storage_info.hpp @@ -13,6 +13,7 @@ #include "duckdb/storage/block.hpp" #include "duckdb/storage/index_storage_info.hpp" #include "duckdb/storage/storage_info.hpp" +#include "duckdb/storage/table/column_data.hpp" #include "duckdb/common/optional_idx.hpp" namespace duckdb { diff --git 
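The new `variant_column_data.hpp` above introduces shredded storage for VARIANT columns: a column counts as shredded when it has exactly two sub-columns, and the static `ShredVariantData`/`UnshredVariantData` hooks convert between the plain and shredded vector representations. A minimal, hypothetical round-trip sketch; only the signatures are taken from the patch, and the expected vector types of `shredded` and `restored` are assumed.

```cpp
#include "duckdb/storage/table/variant_column_data.hpp"

using namespace duckdb;

// Hypothetical round trip: shred a variant vector and unshred it again.
void RoundTripVariant(Vector &variant_input, Vector &shredded, Vector &restored, idx_t count) {
	VariantColumnData::ShredVariantData(variant_input, shredded, count);
	VariantColumnData::UnshredVariantData(shredded, restored, count);
}
```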
a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp index c35820836..7eb828d57 100644 --- a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp +++ b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp @@ -39,6 +39,7 @@ class WriteAheadLogDeserializer; struct PersistentCollectionData; enum class WALInitState { NO_WAL, UNINITIALIZED, UNINITIALIZED_REQUIRES_TRUNCATE, INITIALIZED }; +enum class WALReplayState { MAIN_WAL, CHECKPOINT_WAL }; //! The WriteAheadLog (WAL) is a log that is used to provide durability. Prior //! to committing a transaction it writes the changes the transaction made to @@ -47,18 +48,21 @@ enum class WALInitState { NO_WAL, UNINITIALIZED, UNINITIALIZED_REQUIRES_TRUNCATE class WriteAheadLog { public: //! Initialize the WAL in the specified directory - explicit WriteAheadLog(AttachedDatabase &database, const string &wal_path, idx_t wal_size = 0ULL, - WALInitState state = WALInitState::NO_WAL); + explicit WriteAheadLog(StorageManager &storage_manager, const string &wal_path, idx_t wal_size = 0ULL, + WALInitState state = WALInitState::NO_WAL, + optional_idx checkpoint_iteration = optional_idx()); virtual ~WriteAheadLog(); public: - //! Replay and initialize the WAL - static unique_ptr Replay(FileSystem &fs, AttachedDatabase &database, const string &wal_path); + //! Replay and initialize the WAL, QueryContext is passed for metric collection purposes only!! + static unique_ptr Replay(QueryContext context, StorageManager &storage_manager, + const string &wal_path); AttachedDatabase &GetDatabase(); - //! Gets the total bytes written to the WAL since startup - idx_t GetWALSize() const; + const string &GetPath() const { + return wal_path; + } //! Gets the total bytes written to the WAL since startup idx_t GetTotalWritten() const; @@ -114,22 +118,23 @@ class WriteAheadLog { //! Truncate the WAL to a previous size, and clear anything currently set in the writer void Truncate(idx_t size); - //! Delete the WAL file on disk. The WAL should not be used after this point. - void Delete(); void Flush(); void WriteCheckpoint(MetaBlockPointer meta_block); protected: - static unique_ptr ReplayInternal(AttachedDatabase &database, unique_ptr handle); + //! Internally replay all WAL entries. QueryContext is passed for metric collection purposes only!! + static unique_ptr ReplayInternal(QueryContext context, StorageManager &storage_manager, + unique_ptr handle, + WALReplayState replay_state = WALReplayState::MAIN_WAL); protected: - AttachedDatabase &database; + StorageManager &storage_manager; mutex wal_lock; unique_ptr writer; string wal_path; - atomic wal_size; atomic init_state; + optional_idx checkpoint_iteration; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp index 0de2faabd..0d91a9dc0 100644 --- a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp @@ -11,6 +11,8 @@ #include "duckdb/transaction/undo_buffer.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/transaction/commit_state.hpp" namespace duckdb { @@ -21,29 +23,23 @@ struct UpdateInfo; class CleanupState { public: - explicit CleanupState(transaction_t lowest_active_transaction); - ~CleanupState(); - - // all tables with indexes that possibly need a vacuum (after e.g. 
a delete) - unordered_map> indexed_tables; + explicit CleanupState(const QueryContext &context, transaction_t lowest_active_transaction, + ActiveTransactionState transaction_state); public: void CleanupEntry(UndoFlags type, data_ptr_t data); private: + QueryContext context; //! Lowest active transaction transaction_t lowest_active_transaction; - // data for index cleanup - optional_ptr current_table; - DataChunk chunk; - row_t row_numbers[STANDARD_VECTOR_SIZE]; - idx_t count; + ActiveTransactionState transaction_state; + //! While cleaning up, we remove data from any delta indexes we added data to during the commit + IndexDataRemover index_data_remover; private: void CleanupDelete(DeleteInfo &info); void CleanupUpdate(UpdateInfo &info); - - void Flush(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/commit_state.hpp b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp index 382de7296..3ad975925 100644 --- a/src/duckdb/src/include/duckdb/transaction/commit_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp @@ -10,6 +10,8 @@ #include "duckdb/transaction/undo_buffer.hpp" #include "duckdb/common/vector_size.hpp" +#include "duckdb/common/enums/index_removal_type.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { class CatalogEntry; @@ -19,23 +21,52 @@ class WriteAheadLog; class ClientContext; struct DataTableInfo; +class DataTable; struct DeleteInfo; struct UpdateInfo; +enum class CommitMode { COMMIT, REVERT_COMMIT }; + +struct IndexDataRemover { +public: + explicit IndexDataRemover(QueryContext context, IndexRemovalType removal_type); + + void PushDelete(DeleteInfo &info); + void Verify(); + +private: + void Flush(DataTable &table, row_t *row_numbers, idx_t count); + +private: + // data for index cleanup + QueryContext context; + //! While committing, we remove data from any indexes that was deleted + IndexRemovalType removal_type; + DataChunk chunk; + //! 
Debug mode only - list of indexes to verify + reference_map_t> verify_indexes; +}; + class CommitState { public: - explicit CommitState(DuckTransaction &transaction, transaction_t commit_id); + explicit CommitState(DuckTransaction &transaction, transaction_t commit_id, + ActiveTransactionState transaction_state, CommitMode commit_mode); public: void CommitEntry(UndoFlags type, data_ptr_t data); void RevertCommit(UndoFlags type, data_ptr_t data); + void Flush(); + void Verify(); + static IndexRemovalType GetIndexRemovalType(ActiveTransactionState transaction_state, CommitMode commit_mode); private: void CommitEntryDrop(CatalogEntry &entry, data_ptr_t extra_data); + void CommitDelete(DeleteInfo &info); private: DuckTransaction &transaction; transaction_t commit_id; + IndexDataRemover index_data_remover; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp index 12c4d180c..f0169b3b8 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/reference_map.hpp" #include "duckdb/common/error_data.hpp" #include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/common/enums/active_transaction_state.hpp" namespace duckdb { class CheckpointLock; @@ -23,6 +24,11 @@ class StorageCommitState; struct DataTableInfo; struct UndoBufferProperties; +struct CommitInfo { + transaction_t commit_id; + ActiveTransactionState active_transactions = ActiveTransactionState::UNSET; +}; + class DuckTransaction : public Transaction { public: DuckTransaction(DuckTransactionManager &manager, ClientContext &context, transaction_t start_time, @@ -35,14 +41,12 @@ class DuckTransaction : public Transaction { transaction_t transaction_id; //! The commit id of this transaction, if it has successfully been committed transaction_t commit_id; - //! Highest active query when the transaction finished, used for cleaning up - transaction_t highest_active_query; atomic catalog_version; //! Transactions undergo Cleanup, after (1) removing them directly in RemoveTransaction, - //! or (2) after they exist old_transactions. - //! Some (after rollback) enter old_transactions, but do not require Cleanup. + //! or (2) after they enter cleanup_queue. + //! Some (after rollback) enter cleanup_queue, but do not require Cleanup. bool awaiting_cleanup; public: @@ -53,13 +57,14 @@ class DuckTransaction : public Transaction { void PushCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data, idx_t extra_data_size); void PushAttach(AttachedDatabase &db); - void SetReadWrite() override; + void SetModifications(DatabaseModificationType type) override; bool ShouldWriteToWAL(AttachedDatabase &db); - ErrorData WriteToWAL(AttachedDatabase &db, unique_ptr &commit_state) noexcept; + ErrorData WriteToWAL(ClientContext &context, AttachedDatabase &db, + unique_ptr &commit_state) noexcept; //! Commit the current transaction with the given commit identifier. Returns an error message if the transaction //! commit failed, or an empty string if the commit was sucessful - ErrorData Commit(AttachedDatabase &db, transaction_t commit_id, + ErrorData Commit(AttachedDatabase &db, CommitInfo &commit_info, unique_ptr commit_state) noexcept; //! 
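The `commit_state.hpp` changes above split index maintenance for deleted rows into a dedicated `IndexDataRemover`, parameterized by an `IndexRemovalType` derived from the active-transaction state and the commit mode. The sketch below shows one plausible way these pieces fit together; the surrounding flow (`RemoveDeletedIndexData`) is assumed, not taken from the patch.

```cpp
#include "duckdb/transaction/commit_state.hpp"

using namespace duckdb;

// Hypothetical commit-time flow: derive the removal type, push the deleted rows of a
// DeleteInfo into the remover, and (in debug builds) verify the touched indexes.
void RemoveDeletedIndexData(QueryContext context, ActiveTransactionState transaction_state, DeleteInfo &info) {
	auto removal_type = CommitState::GetIndexRemovalType(transaction_state, CommitMode::COMMIT);
	IndexDataRemover remover(context, removal_type);
	remover.PushDelete(info);
	remover.Verify();
}
```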
Returns whether or not a commit of this transaction should trigger an automatic checkpoint bool AutomaticCheckpoint(AttachedDatabase &db, const UndoBufferProperties &properties); @@ -76,8 +81,9 @@ class DuckTransaction : public Transaction { idx_t base_row); void PushSequenceUsage(SequenceCatalogEntry &entry, const SequenceData &data); void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); - UndoBufferReference CreateUpdateInfo(idx_t type_size, idx_t entries); + UndoBufferReference CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries, idx_t row_group_start); + DuckTransactionManager &GetTransactionManager(); bool IsDuckTransaction() const override { return true; } @@ -90,10 +96,10 @@ class DuckTransaction : public Transaction { //! Get a shared lock on a table shared_ptr SharedLockTable(DataTableInfo &info); + //! Hold an owning reference of the table, needed to safely reference it inside the transaction commit/undo logic void ModifyTable(DataTable &tbl); private: - DuckTransactionManager &transaction_manager; //! The undo buffer is used to store old versions of rows that are updated //! or deleted UndoBuffer undo_buffer; diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp index 63531ae7d..42ff94731 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp @@ -14,6 +14,7 @@ #include "duckdb/common/queue.hpp" namespace duckdb { +class DuckTransactionManager; class DuckTransaction; struct UndoBufferProperties; @@ -28,6 +29,13 @@ struct DuckCleanupInfo { bool ScheduleCleanup() noexcept; }; +struct ActiveCheckpointWrapper { + explicit ActiveCheckpointWrapper(DuckTransactionManager &manager); + ~ActiveCheckpointWrapper(); + + DuckTransactionManager &manager; +}; + //! The Transaction Manager is responsible for creating and managing //! transactions class DuckTransactionManager : public TransactionManager { @@ -56,6 +64,11 @@ class DuckTransactionManager : public TransactionManager { transaction_t GetLastCommit() const { return last_commit; } + transaction_t GetActiveCheckpoint() const { + return active_checkpoint; + } + transaction_t GetNewCheckpointId(); + void ResetCheckpointId(); bool IsDuckTransactionManager() override { return true; @@ -63,6 +76,8 @@ class DuckTransactionManager : public TransactionManager { //! Obtains a shared lock to the checkpoint lock unique_ptr SharedCheckpointLock(); + //! Try to obtain an exclusive checkpoint lock + unique_ptr TryGetCheckpointLock(); unique_ptr TryUpgradeCheckpointLock(StorageLockKey &lock); //! Returns the current version of the catalog (incremented whenever anything changes, not stored between restarts) @@ -94,6 +109,7 @@ class DuckTransactionManager : public TransactionManager { //! Whether or not we can checkpoint CheckpointDecision CanCheckpoint(DuckTransaction &transaction, unique_ptr &checkpoint_lock, const UndoBufferProperties &properties); + bool HasOtherTransactions(DuckTransaction &transaction); private: //! The current start timestamp used by transactions @@ -106,20 +122,18 @@ class DuckTransactionManager : public TransactionManager { atomic lowest_active_start; //! The last commit timestamp atomic last_commit; + //! The currently active checkpoint + atomic active_checkpoint; //! Set of currently running transactions vector> active_transactions; //! 
Set of recently committed transactions vector> recently_committed_transactions; - //! Transactions awaiting GC - vector> old_transactions; //! The lock used for transaction operations mutex transaction_lock; //! The checkpoint lock StorageLock checkpoint_lock; //! Lock necessary to start transactions only - used by FORCE CHECKPOINT to prevent new transactions from starting mutex start_transaction_lock; - //! Mutex used to control writes to the WAL - separate from the transaction lock - mutex wal_lock; atomic last_uncommitted_catalog_version = {TRANSACTION_ID_START}; idx_t last_committed_version = 0; diff --git a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp index 5d29da46c..1cab839d0 100644 --- a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp +++ b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp @@ -40,6 +40,8 @@ class LocalTableStorage : public enable_shared_from_this { ExpressionExecutor &default_executor); ~LocalTableStorage(); + QueryContext context; + reference table_ref; Allocator &allocator; @@ -189,6 +191,10 @@ class LocalStorage { void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + ClientContext &GetClientContext() const { + return context; + } + private: ClientContext &context; DuckTransaction &transaction; diff --git a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp index 71693ee14..8209963a4 100644 --- a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp @@ -21,6 +21,7 @@ namespace duckdb { class AttachedDatabase; class ClientContext; +struct DatabaseModificationType; class Transaction; enum class TransactionState { UNCOMMITTED, COMMITTED, ROLLED_BACK }; @@ -68,7 +69,7 @@ class MetaTransaction { void SetReadOnly(); bool IsReadOnly() const; - void ModifyDatabase(AttachedDatabase &db); + void ModifyDatabase(AttachedDatabase &db, DatabaseModificationType modification); optional_ptr ModifiedDatabase() { return modified_database; } diff --git a/src/duckdb/src/include/duckdb/transaction/transaction.hpp b/src/duckdb/src/include/duckdb/transaction/transaction.hpp index db0a90eee..6fc519696 100644 --- a/src/duckdb/src/include/duckdb/transaction/transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/transaction.hpp @@ -15,6 +15,7 @@ #include "duckdb/common/atomic.hpp" namespace duckdb { +class Catalog; class SequenceCatalogEntry; class SchemaCatalogEntry; @@ -33,6 +34,7 @@ class ChunkVectorInfo; struct DeleteInfo; struct UpdateInfo; +struct DatabaseModificationType; //! The transaction object holds information about a currently running or past //! transaction @@ -57,6 +59,8 @@ class Transaction { DUCKDB_API bool IsReadOnly(); //! Promotes the transaction to a read-write transaction DUCKDB_API virtual void SetReadWrite(); + //! 
Sets the database modifications that are planned to be performed in this transaction + DUCKDB_API virtual void SetModifications(DatabaseModificationType type); virtual bool IsDuckTransaction() const { return false; diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp index c117ab09e..c9bef2daf 100644 --- a/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp +++ b/src/duckdb/src/include/duckdb/transaction/transaction_manager.hpp @@ -53,6 +53,18 @@ class TransactionManager { protected: //! The attached database AttachedDatabase &db; + +public: + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + D_ASSERT(dynamic_cast(this)); + return reinterpret_cast(*this); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp b/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp index 7218bb876..b4e434265 100644 --- a/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp +++ b/src/duckdb/src/include/duckdb/transaction/undo_buffer.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/enums/undo_flags.hpp" #include "duckdb/transaction/undo_buffer_allocator.hpp" +#include "duckdb/common/enums/active_transaction_state.hpp" namespace duckdb { class BufferManager; @@ -18,6 +19,7 @@ class DuckTransaction; class StorageCommitState; class WriteAheadLog; struct UndoBufferPointer; +struct CommitInfo; struct UndoBufferProperties { idx_t estimated_size = 0; @@ -38,6 +40,7 @@ class UndoBuffer { optional_ptr current; data_ptr_t start; data_ptr_t end; + bool started = false; }; public: @@ -54,7 +57,7 @@ class UndoBuffer { //! Commit the changes made in the UndoBuffer: should be called on commit void WriteToWAL(WriteAheadLog &wal, optional_ptr commit_state); //! Commit the changes made in the UndoBuffer: should be called on commit - void Commit(UndoBuffer::IteratorState &iterator_state, transaction_t commit_id); + void Commit(UndoBuffer::IteratorState &iterator_state, CommitInfo &info); //! Revert committed changes made in the UndoBuffer up until the currently committed state void RevertCommit(UndoBuffer::IteratorState &iterator_state, transaction_t transaction_id); //! Rollback the changes made in this UndoBuffer: should be called on @@ -64,6 +67,7 @@ class UndoBuffer { private: DuckTransaction &transaction; UndoBufferAllocator allocator; + ActiveTransactionState active_transaction_state = ActiveTransactionState::UNSET; private: template diff --git a/src/duckdb/src/include/duckdb/transaction/update_info.hpp b/src/duckdb/src/include/duckdb/transaction/update_info.hpp index 7cccd923e..cde47e2b4 100644 --- a/src/duckdb/src/include/duckdb/transaction/update_info.hpp +++ b/src/duckdb/src/include/duckdb/transaction/update_info.hpp @@ -17,6 +17,7 @@ namespace duckdb { class UpdateSegment; struct DataTableInfo; +class DataTable; //! UpdateInfo is a class that represents a set of updates applied to a single vector. //! The UpdateInfo struct contains metadata associated with the update. @@ -26,8 +27,12 @@ struct DataTableInfo; struct UpdateInfo { //! The update segment that this update info affects UpdateSegment *segment; + //! The table this was update was made on + DataTable *table; //! The column index of which column we are updating idx_t column_index; + //! The start index of the row group + idx_t row_group_start; //! 
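`Transaction::SetModifications` above lets write paths report what kind of modification they intend to perform (DuckTransaction now overrides this instead of `SetReadWrite`), and `MetaTransaction::ModifyDatabase` takes the same `DatabaseModificationType`. Below is a hypothetical sketch of a write path flagging its modification; `MarkWrite` is an illustrative helper, and since `DatabaseModificationType` is only forward-declared in these headers, its defining header is assumed to be included elsewhere.

```cpp
#include "duckdb/transaction/transaction.hpp"
#include "duckdb/transaction/meta_transaction.hpp"
// plus the header that defines DatabaseModificationType (not shown in this excerpt)

using namespace duckdb;

// Hypothetical write-path helper: record the planned modification on both the
// database-local transaction and the meta transaction.
void MarkWrite(MetaTransaction &meta_transaction, Transaction &transaction, AttachedDatabase &db,
               DatabaseModificationType modification) {
	transaction.SetModifications(modification);
	meta_transaction.ModifyDatabase(db, modification);
}
```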
The version number atomic version_number; //! The vector index within the uncompressed segment @@ -87,7 +92,8 @@ struct UpdateInfo { //! Returns the total allocation size for an UpdateInfo entry, together with space for the tuple data static idx_t GetAllocSize(idx_t type_size); //! Initialize an UpdateInfo struct that has been allocated using GetAllocSize (i.e. has extra space after it) - static void Initialize(UpdateInfo &info, transaction_t transaction_id); + static void Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id, + idx_t row_group_start); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp index aad1a672c..4c68da487 100644 --- a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp @@ -31,7 +31,7 @@ class WALWriteState { void CommitEntry(UndoFlags type, data_ptr_t data); private: - void SwitchTable(DataTableInfo *table, UndoFlags new_op); + void SwitchTable(DataTableInfo &table, UndoFlags new_op); void WriteCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data); void WriteDelete(DeleteInfo &info); diff --git a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp index a60abf187..77fed9815 100644 --- a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp +++ b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp @@ -85,6 +85,8 @@ class StatementVerifier { private: const vector> empty_select_list = {}; + + const vector> &GetSelectList(QueryNode &node); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index 7c5136059..5434aaa2d 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -544,6 +544,7 @@ typedef struct { duckdb_state (*duckdb_appender_create_query)(duckdb_connection connection, const char *query, idx_t column_count, duckdb_logical_type *types, const char *table_name, const char **column_names, duckdb_appender *out_appender); + duckdb_state (*duckdb_appender_clear)(duckdb_appender appender); #endif // New arrow interface functions @@ -560,6 +561,82 @@ typedef struct { void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); #endif +// New functions for interacting with catalog entries +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_catalog (*duckdb_client_context_get_catalog)(duckdb_client_context context, const char *catalog_name); + const char *(*duckdb_catalog_get_type_name)(duckdb_catalog catalog); + duckdb_catalog_entry (*duckdb_catalog_get_entry)(duckdb_catalog catalog, duckdb_client_context context, + duckdb_catalog_entry_type entry_type, const char *schema_name, + const char *entry_name); + void (*duckdb_destroy_catalog)(duckdb_catalog *catalog); + duckdb_catalog_entry_type (*duckdb_catalog_entry_get_type)(duckdb_catalog_entry entry); + const char *(*duckdb_catalog_entry_get_name)(duckdb_catalog_entry entry); + void (*duckdb_destroy_catalog_entry)(duckdb_catalog_entry *entry); +#endif + +// New configuration options functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_config_option (*duckdb_create_config_option)(); + void (*duckdb_destroy_config_option)(duckdb_config_option *option); + void (*duckdb_config_option_set_name)(duckdb_config_option option, const 
char *name); + void (*duckdb_config_option_set_type)(duckdb_config_option option, duckdb_logical_type type); + void (*duckdb_config_option_set_default_value)(duckdb_config_option option, duckdb_value default_value); + void (*duckdb_config_option_set_default_scope)(duckdb_config_option option, + duckdb_config_option_scope default_scope); + void (*duckdb_config_option_set_description)(duckdb_config_option option, const char *description); + duckdb_state (*duckdb_register_config_option)(duckdb_connection connection, duckdb_config_option option); + duckdb_value (*duckdb_client_context_get_config_option)(duckdb_client_context context, const char *name, + duckdb_config_option_scope *out_scope); +#endif + +// API to define custom copy functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_copy_function (*duckdb_create_copy_function)(); + void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); + void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, + duckdb_delete_callback_t destructor); + duckdb_state (*duckdb_register_copy_function)(duckdb_connection connection, duckdb_copy_function copy_function); + void (*duckdb_destroy_copy_function)(duckdb_copy_function *copy_function); + void (*duckdb_copy_function_set_bind)(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind); + void (*duckdb_copy_function_bind_set_error)(duckdb_copy_function_bind_info info, const char *error); + void *(*duckdb_copy_function_bind_get_extra_info)(duckdb_copy_function_bind_info info); + duckdb_client_context (*duckdb_copy_function_bind_get_client_context)(duckdb_copy_function_bind_info info); + idx_t (*duckdb_copy_function_bind_get_column_count)(duckdb_copy_function_bind_info info); + duckdb_logical_type (*duckdb_copy_function_bind_get_column_type)(duckdb_copy_function_bind_info info, + idx_t col_idx); + duckdb_value (*duckdb_copy_function_bind_get_options)(duckdb_copy_function_bind_info info); + void (*duckdb_copy_function_bind_set_bind_data)(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor); + void (*duckdb_copy_function_set_global_init)(duckdb_copy_function copy_function, + duckdb_copy_function_global_init_t init); + void (*duckdb_copy_function_global_init_set_error)(duckdb_copy_function_global_init_info info, const char *error); + void *(*duckdb_copy_function_global_init_get_extra_info)(duckdb_copy_function_global_init_info info); + duckdb_client_context (*duckdb_copy_function_global_init_get_client_context)( + duckdb_copy_function_global_init_info info); + void *(*duckdb_copy_function_global_init_get_bind_data)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_global_init_set_global_state)(duckdb_copy_function_global_init_info info, + void *global_state, duckdb_delete_callback_t destructor); + const char *(*duckdb_copy_function_global_init_get_file_path)(duckdb_copy_function_global_init_info info); + void (*duckdb_copy_function_set_sink)(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function); + void (*duckdb_copy_function_sink_set_error)(duckdb_copy_function_sink_info info, const char *error); + void *(*duckdb_copy_function_sink_get_extra_info)(duckdb_copy_function_sink_info info); + duckdb_client_context (*duckdb_copy_function_sink_get_client_context)(duckdb_copy_function_sink_info info); + void *(*duckdb_copy_function_sink_get_bind_data)(duckdb_copy_function_sink_info info); + void 
*(*duckdb_copy_function_sink_get_global_state)(duckdb_copy_function_sink_info info); + void (*duckdb_copy_function_set_finalize)(duckdb_copy_function copy_function, + duckdb_copy_function_finalize_t finalize); + void (*duckdb_copy_function_finalize_set_error)(duckdb_copy_function_finalize_info info, const char *error); + void *(*duckdb_copy_function_finalize_get_extra_info)(duckdb_copy_function_finalize_info info); + duckdb_client_context (*duckdb_copy_function_finalize_get_client_context)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_bind_data)(duckdb_copy_function_finalize_info info); + void *(*duckdb_copy_function_finalize_get_global_state)(duckdb_copy_function_finalize_info info); + void (*duckdb_copy_function_set_copy_from_function)(duckdb_copy_function copy_function, + duckdb_table_function table_function); + idx_t (*duckdb_table_function_bind_get_result_column_count)(duckdb_bind_info info); + const char *(*duckdb_table_function_bind_get_result_column_name)(duckdb_bind_info info, idx_t col_idx); + duckdb_logical_type (*duckdb_table_function_bind_get_result_column_type)(duckdb_bind_info info, idx_t col_idx); +#endif + // New functions for duckdb error data #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -600,6 +677,18 @@ typedef struct { int64_t (*duckdb_file_handle_size)(duckdb_file_handle file_handle); #endif +// API to register a custom log storage. +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_log_storage (*duckdb_create_log_storage)(); + void (*duckdb_destroy_log_storage)(duckdb_log_storage *log_storage); + void (*duckdb_log_storage_set_write_log_entry)(duckdb_log_storage log_storage, + duckdb_logger_write_log_entry_t function); + void (*duckdb_log_storage_set_extra_data)(duckdb_log_storage log_storage, void *extra_data, + duckdb_delete_callback_t delete_callback); + void (*duckdb_log_storage_set_name)(duckdb_log_storage log_storage, const char *name); + duckdb_state (*duckdb_register_log_storage)(duckdb_database database, duckdb_log_storage log_storage); +#endif + // New functions around the client context #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE idx_t (*duckdb_client_context_get_connection_id)(duckdb_client_context context); @@ -643,6 +732,13 @@ typedef struct { char *(*duckdb_value_to_string)(duckdb_value value); #endif +// New functions around the table description +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); +#endif + // New functions around table function binding #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1093,6 +1189,7 @@ typedef struct { // Version unstable_new_append_functions #define duckdb_appender_create_query duckdb_ext_api.duckdb_appender_create_query #define duckdb_appender_error_data duckdb_ext_api.duckdb_appender_error_data +#define duckdb_appender_clear duckdb_ext_api.duckdb_appender_clear #define duckdb_append_default_to_chunk duckdb_ext_api.duckdb_append_default_to_chunk // Version unstable_new_arrow_functions @@ -1102,6 +1199,69 @@ typedef struct { #define duckdb_data_chunk_from_arrow duckdb_ext_api.duckdb_data_chunk_from_arrow #define duckdb_destroy_arrow_converted_schema 
duckdb_ext_api.duckdb_destroy_arrow_converted_schema +// Version unstable_new_catalog_interface +#define duckdb_client_context_get_catalog duckdb_ext_api.duckdb_client_context_get_catalog +#define duckdb_catalog_get_type_name duckdb_ext_api.duckdb_catalog_get_type_name +#define duckdb_catalog_get_entry duckdb_ext_api.duckdb_catalog_get_entry +#define duckdb_destroy_catalog duckdb_ext_api.duckdb_destroy_catalog +#define duckdb_catalog_entry_get_type duckdb_ext_api.duckdb_catalog_entry_get_type +#define duckdb_catalog_entry_get_name duckdb_ext_api.duckdb_catalog_entry_get_name +#define duckdb_destroy_catalog_entry duckdb_ext_api.duckdb_destroy_catalog_entry + +// Version unstable_new_config_options_functions +#define duckdb_create_config_option duckdb_ext_api.duckdb_create_config_option +#define duckdb_destroy_config_option duckdb_ext_api.duckdb_destroy_config_option +#define duckdb_config_option_set_name duckdb_ext_api.duckdb_config_option_set_name +#define duckdb_config_option_set_type duckdb_ext_api.duckdb_config_option_set_type +#define duckdb_config_option_set_default_value duckdb_ext_api.duckdb_config_option_set_default_value +#define duckdb_config_option_set_default_scope duckdb_ext_api.duckdb_config_option_set_default_scope +#define duckdb_config_option_set_description duckdb_ext_api.duckdb_config_option_set_description +#define duckdb_register_config_option duckdb_ext_api.duckdb_register_config_option +#define duckdb_client_context_get_config_option duckdb_ext_api.duckdb_client_context_get_config_option + +// Version unstable_new_copy_functions_api +#define duckdb_create_copy_function duckdb_ext_api.duckdb_create_copy_function +#define duckdb_copy_function_set_name duckdb_ext_api.duckdb_copy_function_set_name +#define duckdb_copy_function_set_extra_info duckdb_ext_api.duckdb_copy_function_set_extra_info +#define duckdb_register_copy_function duckdb_ext_api.duckdb_register_copy_function +#define duckdb_destroy_copy_function duckdb_ext_api.duckdb_destroy_copy_function +#define duckdb_copy_function_set_bind duckdb_ext_api.duckdb_copy_function_set_bind +#define duckdb_copy_function_bind_set_error duckdb_ext_api.duckdb_copy_function_bind_set_error +#define duckdb_copy_function_bind_get_extra_info duckdb_ext_api.duckdb_copy_function_bind_get_extra_info +#define duckdb_copy_function_bind_get_client_context duckdb_ext_api.duckdb_copy_function_bind_get_client_context +#define duckdb_copy_function_bind_get_column_count duckdb_ext_api.duckdb_copy_function_bind_get_column_count +#define duckdb_copy_function_bind_get_column_type duckdb_ext_api.duckdb_copy_function_bind_get_column_type +#define duckdb_copy_function_bind_get_options duckdb_ext_api.duckdb_copy_function_bind_get_options +#define duckdb_copy_function_bind_set_bind_data duckdb_ext_api.duckdb_copy_function_bind_set_bind_data +#define duckdb_copy_function_set_global_init duckdb_ext_api.duckdb_copy_function_set_global_init +#define duckdb_copy_function_global_init_set_error duckdb_ext_api.duckdb_copy_function_global_init_set_error +#define duckdb_copy_function_global_init_get_extra_info duckdb_ext_api.duckdb_copy_function_global_init_get_extra_info +#define duckdb_copy_function_global_init_get_client_context \ + duckdb_ext_api.duckdb_copy_function_global_init_get_client_context +#define duckdb_copy_function_global_init_get_bind_data duckdb_ext_api.duckdb_copy_function_global_init_get_bind_data +#define duckdb_copy_function_global_init_get_file_path duckdb_ext_api.duckdb_copy_function_global_init_get_file_path +#define 
duckdb_copy_function_global_init_set_global_state \ + duckdb_ext_api.duckdb_copy_function_global_init_set_global_state +#define duckdb_copy_function_set_sink duckdb_ext_api.duckdb_copy_function_set_sink +#define duckdb_copy_function_sink_set_error duckdb_ext_api.duckdb_copy_function_sink_set_error +#define duckdb_copy_function_sink_get_extra_info duckdb_ext_api.duckdb_copy_function_sink_get_extra_info +#define duckdb_copy_function_sink_get_client_context duckdb_ext_api.duckdb_copy_function_sink_get_client_context +#define duckdb_copy_function_sink_get_bind_data duckdb_ext_api.duckdb_copy_function_sink_get_bind_data +#define duckdb_copy_function_sink_get_global_state duckdb_ext_api.duckdb_copy_function_sink_get_global_state +#define duckdb_copy_function_set_finalize duckdb_ext_api.duckdb_copy_function_set_finalize +#define duckdb_copy_function_finalize_set_error duckdb_ext_api.duckdb_copy_function_finalize_set_error +#define duckdb_copy_function_finalize_get_extra_info duckdb_ext_api.duckdb_copy_function_finalize_get_extra_info +#define duckdb_copy_function_finalize_get_client_context duckdb_ext_api.duckdb_copy_function_finalize_get_client_context +#define duckdb_copy_function_finalize_get_bind_data duckdb_ext_api.duckdb_copy_function_finalize_get_bind_data +#define duckdb_copy_function_finalize_get_global_state duckdb_ext_api.duckdb_copy_function_finalize_get_global_state +#define duckdb_copy_function_set_copy_from_function duckdb_ext_api.duckdb_copy_function_set_copy_from_function +#define duckdb_table_function_bind_get_result_column_count \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_count +#define duckdb_table_function_bind_get_result_column_name \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_name +#define duckdb_table_function_bind_get_result_column_type \ + duckdb_ext_api.duckdb_table_function_bind_get_result_column_type + // Version unstable_new_error_data_functions #define duckdb_create_error_data duckdb_ext_api.duckdb_create_error_data #define duckdb_destroy_error_data duckdb_ext_api.duckdb_destroy_error_data @@ -1133,6 +1293,14 @@ typedef struct { #define duckdb_file_handle_sync duckdb_ext_api.duckdb_file_handle_sync #define duckdb_file_handle_close duckdb_ext_api.duckdb_file_handle_close +// Version unstable_new_logger_functions +#define duckdb_create_log_storage duckdb_ext_api.duckdb_create_log_storage +#define duckdb_destroy_log_storage duckdb_ext_api.duckdb_destroy_log_storage +#define duckdb_log_storage_set_write_log_entry duckdb_ext_api.duckdb_log_storage_set_write_log_entry +#define duckdb_log_storage_set_extra_data duckdb_ext_api.duckdb_log_storage_set_extra_data +#define duckdb_log_storage_set_name duckdb_ext_api.duckdb_log_storage_set_name +#define duckdb_register_log_storage duckdb_ext_api.duckdb_register_log_storage + // Version unstable_new_open_connect_functions #define duckdb_connection_get_client_context duckdb_ext_api.duckdb_connection_get_client_context #define duckdb_connection_get_arrow_options duckdb_ext_api.duckdb_connection_get_arrow_options @@ -1164,6 +1332,10 @@ typedef struct { // Version unstable_new_string_functions #define duckdb_value_to_string duckdb_ext_api.duckdb_value_to_string +// Version unstable_new_table_description_functions +#define duckdb_table_description_get_column_count duckdb_ext_api.duckdb_table_description_get_column_count +#define duckdb_table_description_get_column_type duckdb_ext_api.duckdb_table_description_get_column_type + // Version unstable_new_table_function_functions #define 
duckdb_table_function_get_client_context duckdb_ext_api.duckdb_table_function_get_client_context @@ -1222,9 +1394,9 @@ typedef struct { DUCKDB_EXTENSION_EXTERN_C_GUARD_OPEN DUCKDB_CAPI_ENTRY_VISIBILITY DUCKDB_EXTENSION_API bool DUCKDB_EXTENSION_GLUE( \ DUCKDB_EXTENSION_NAME, _init_c_api)(duckdb_extension_info info, struct duckdb_extension_access * access) { \ DUCKDB_EXTENSION_API_INIT(info, access, DUCKDB_EXTENSION_API_VERSION_STRING); \ - duckdb_database *db = access->get_database(info); \ + duckdb_database db = access->get_database(info); \ duckdb_connection conn; \ - if (duckdb_connect(*db, &conn) == DuckDBError) { \ + if (duckdb_connect(db, &conn) == DuckDBError) { \ access->set_error(info, "Failed to open connection to database"); \ return false; \ } \ diff --git a/src/duckdb/src/logging/log_manager.cpp b/src/duckdb/src/logging/log_manager.cpp index 2785386ab..07b31f7af 100644 --- a/src/duckdb/src/logging/log_manager.cpp +++ b/src/duckdb/src/logging/log_manager.cpp @@ -205,7 +205,7 @@ void LogManager::SetEnableStructuredLoggers(vector &enabled_logger_types throw InvalidInputException("Unknown log type: '%s'", enabled_logger_type); } - new_config.enabled_log_types.insert(enabled_logger_type); + new_config.enabled_log_types.insert(lookup->name); min_log_level = MinValue(min_log_level, lookup->level); } @@ -266,6 +266,7 @@ void LogManager::RegisterDefaultLogTypes() { RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); + RegisterLogType(make_uniq()); } } // namespace duckdb diff --git a/src/duckdb/src/logging/log_storage.cpp b/src/duckdb/src/logging/log_storage.cpp index c6733d968..d165ef043 100644 --- a/src/duckdb/src/logging/log_storage.cpp +++ b/src/duckdb/src/logging/log_storage.cpp @@ -14,34 +14,27 @@ #include "duckdb/function/cast/vector_cast_helpers.hpp" #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" +#include "duckdb/common/printer.hpp" #include -#include namespace duckdb { vector LogStorage::GetSchema(LoggingTargetTable table) { switch (table) { - case LoggingTargetTable::ALL_LOGS: - return { - LogicalType::UBIGINT, // context_id - LogicalType::VARCHAR, // scope - LogicalType::UBIGINT, // connection_id - LogicalType::UBIGINT, // transaction_id - LogicalType::UBIGINT, // query_id - LogicalType::UBIGINT, // thread - LogicalType::TIMESTAMP, // timestamp - LogicalType::VARCHAR, // log_type - LogicalType::VARCHAR, // level - LogicalType::VARCHAR, // message - }; + case LoggingTargetTable::ALL_LOGS: { + auto all_logs = GetSchema(LoggingTargetTable::LOG_CONTEXTS); + auto log_entries = GetSchema(LoggingTargetTable::LOG_ENTRIES); + all_logs.insert(all_logs.end(), log_entries.begin() + 1, log_entries.end()); + return all_logs; + } case LoggingTargetTable::LOG_ENTRIES: return { - LogicalType::UBIGINT, // context_id - LogicalType::TIMESTAMP, // timestamp - LogicalType::VARCHAR, // log_type - LogicalType::VARCHAR, // level - LogicalType::VARCHAR, // message + LogicalType::UBIGINT, // context_id + LogicalType::TIMESTAMP_TZ, // timestamp + LogicalType::VARCHAR, // log_type + LogicalType::VARCHAR, // level + LogicalType::VARCHAR, // message }; case LoggingTargetTable::LOG_CONTEXTS: return { @@ -59,11 +52,12 @@ vector LogStorage::GetSchema(LoggingTargetTable table) { vector LogStorage::GetColumnNames(LoggingTargetTable table) { switch (table) { - case LoggingTargetTable::ALL_LOGS: - return { - "context_id", "scope", "connection_id", "transaction_id", "query_id", - "thread_id", 
"timestamp", "type", "log_level", "message", - }; + case LoggingTargetTable::ALL_LOGS: { + auto all_logs = GetColumnNames(LoggingTargetTable::LOG_CONTEXTS); + auto log_entries = GetColumnNames(LoggingTargetTable::LOG_ENTRIES); + all_logs.insert(all_logs.end(), log_entries.begin() + 1, log_entries.end()); + return all_logs; + } case LoggingTargetTable::LOG_ENTRIES: return {"context_id", "timestamp", "type", "log_level", "message"}; case LoggingTargetTable::LOG_CONTEXTS: @@ -258,8 +252,9 @@ void BufferingLogStorage::UpdateConfigInternal(DatabaseInstance &db, case_insens } void StdOutLogStorage::StdOutWriteStream::WriteData(const_data_ptr_t buffer, idx_t write_size) { - std::cout.write(const_char_ptr_cast(buffer), NumericCast(write_size)); - std::cout.flush(); + string data(const_char_ptr_cast(buffer), NumericCast(write_size)); + Printer::RawPrint(OutputStream::STREAM_STDOUT, data); + Printer::Flush(OutputStream::STREAM_STDOUT); } StdOutLogStorage::StdOutLogStorage(DatabaseInstance &db) : CSVLogStorage(db, false, 1) { @@ -599,7 +594,6 @@ BufferingLogStorage::~BufferingLogStorage() { } static void WriteLoggingContextsToChunk(DataChunk &chunk, const RegisteredLoggingContext &context, idx_t &col) { - auto size = chunk.size(); auto context_id_data = FlatVector::GetData(chunk.data[col++]); diff --git a/src/duckdb/src/logging/log_types.cpp b/src/duckdb/src/logging/log_types.cpp index f78abae59..0fe81b8e0 100644 --- a/src/duckdb/src/logging/log_types.cpp +++ b/src/duckdb/src/logging/log_types.cpp @@ -14,6 +14,7 @@ constexpr LogLevel FileSystemLogType::LEVEL; constexpr LogLevel QueryLogType::LEVEL; constexpr LogLevel HTTPLogType::LEVEL; constexpr LogLevel PhysicalOperatorLogType::LEVEL; +constexpr LogLevel MetricsLogType::LEVEL; constexpr LogLevel CheckpointLogType::LEVEL; //===--------------------------------------------------------------------===// @@ -58,6 +59,8 @@ LogicalType HTTPLogType::GetLogType() { child_list_t request_child_list = { {"type", LogicalType::VARCHAR}, {"url", LogicalType::VARCHAR}, + {"start_time", LogicalType::TIMESTAMP_TZ}, + {"duration_ms", LogicalType::BIGINT}, {"headers", LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)}, }; auto request_type = LogicalType::STRUCT(request_child_list); @@ -90,7 +93,10 @@ string HTTPLogType::ConstructLogMessage(BaseRequest &request, optional_ptr child_list = { + {"metric", LogicalType::VARCHAR}, + {"value", LogicalType::VARCHAR}, + }; + return LogicalType::STRUCT(child_list); +} + +string MetricsLogType::ConstructLogMessage(const MetricType &metric, const Value &value) { + child_list_t child_list = { + {"metric", EnumUtil::ToString(metric)}, + {"value", value.ToString()}, + }; + return Value::STRUCT(std::move(child_list)).ToString(); +} + //===--------------------------------------------------------------------===// // CheckpointLogType //===--------------------------------------------------------------------===// @@ -187,10 +216,38 @@ string CheckpointLogType::ConstructLogMessage(const AttachedDatabase &db, DataTa } string CheckpointLogType::ConstructLogMessage(const AttachedDatabase &db, DataTableInfo &table, idx_t segment_idx, - RowGroup &row_group) { + RowGroup &row_group, idx_t row_group_start) { vector map_keys = {"segment_idx", "start", "count"}; - vector map_values = {to_string(segment_idx), to_string(row_group.start), to_string(row_group.count.load())}; + vector map_values = {to_string(segment_idx), to_string(row_group_start), to_string(row_group.count.load())}; return CreateLog(db, table, "checkpoint", 
std::move(map_keys), std::move(map_values)); } +//===--------------------------------------------------------------------===// +// TransactionLogType +//===--------------------------------------------------------------------===// +constexpr LogLevel TransactionLogType::LEVEL; + +TransactionLogType::TransactionLogType() : LogType(NAME, LEVEL, GetLogType()) { +} + +LogicalType TransactionLogType::GetLogType() { + child_list_t child_list = { + {"database", LogicalType::VARCHAR}, + {"type", LogicalType::VARCHAR}, + {"transaction_id", LogicalType::UBIGINT}, + }; + return LogicalType::STRUCT(child_list); +} + +string TransactionLogType::ConstructLogMessage(const AttachedDatabase &db, const char *log_type, + transaction_t transaction_id) { + child_list_t child_list = { + {"database", db.name}, + {"type", log_type}, + {"transaction_id", transaction_id == MAX_TRANSACTION_ID ? Value() : Value::UBIGINT(transaction_id)}, + }; + + return Value::STRUCT(std::move(child_list)).ToString(); +} + } // namespace duckdb diff --git a/src/duckdb/src/main/appender.cpp b/src/duckdb/src/main/appender.cpp index bac1c06f1..d02150ae9 100644 --- a/src/duckdb/src/main/appender.cpp +++ b/src/duckdb/src/main/appender.cpp @@ -41,6 +41,7 @@ void BaseAppender::Destructor() { try { Close(); } catch (...) { // NOLINT + // FIXME: Make any log context available here. } } @@ -66,7 +67,7 @@ void BaseAppender::EndRow() { } column = 0; chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() >= STANDARD_VECTOR_SIZE) { + if (ShouldFlushChunk()) { FlushChunk(); } } @@ -338,7 +339,7 @@ void BaseAppender::AppendDataChunk(DataChunk &chunk_p) { // Early-out, if types match. if (chunk_types == appender_types) { collection->Append(chunk_p); - if (collection->Count() >= flush_count) { + if (ShouldFlush()) { Flush(); } return; @@ -371,7 +372,7 @@ void BaseAppender::AppendDataChunk(DataChunk &chunk_p) { } collection->Append(cast_chunk); - if (collection->Count() >= flush_count) { + if (ShouldFlush()) { Flush(); } } @@ -382,7 +383,7 @@ void BaseAppender::FlushChunk() { } collection->Append(chunk); chunk.Reset(); - if (collection->Count() >= flush_count) { + if (ShouldFlush()) { Flush(); } } @@ -422,8 +423,12 @@ void BaseAppender::ClearColumns() { //===--------------------------------------------------------------------===// // Table Appender //===--------------------------------------------------------------------===// -Appender::Appender(Connection &con, const string &database_name, const string &schema_name, const string &table_name) +Appender::Appender(Connection &con, const string &database_name, const string &schema_name, const string &table_name, + const idx_t flush_memory_threshold_p) : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context) { + flush_memory_threshold = (flush_memory_threshold_p == DConstants::INVALID_INDEX) + ? 
optional_idx::Invalid() + : optional_idx(flush_memory_threshold_p); description = con.TableInfo(database_name, schema_name, table_name); if (!description) { @@ -480,12 +485,13 @@ Appender::Appender(Connection &con, const string &database_name, const string &s collection = make_uniq(allocator, GetActiveTypes()); } -Appender::Appender(Connection &con, const string &schema_name, const string &table_name) - : Appender(con, INVALID_CATALOG, schema_name, table_name) { +Appender::Appender(Connection &con, const string &schema_name, const string &table_name, + const idx_t flush_memory_threshold_p) + : Appender(con, INVALID_CATALOG, schema_name, table_name, flush_memory_threshold_p) { } -Appender::Appender(Connection &con, const string &table_name) - : Appender(con, INVALID_CATALOG, DEFAULT_SCHEMA, table_name) { +Appender::Appender(Connection &con, const string &table_name, const idx_t flush_memory_threshold_p) + : Appender(con, INVALID_CATALOG, DEFAULT_SCHEMA, table_name, flush_memory_threshold_p) { } Appender::~Appender() { @@ -577,12 +583,15 @@ void Appender::ClearColumns() { // Query Appender //===--------------------------------------------------------------------===// QueryAppender::QueryAppender(Connection &con, string query_p, vector types_p, vector names_p, - string table_name_p) + string table_name_p, const idx_t flush_memory_threshold_p) : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context), query(std::move(query_p)), names(std::move(names_p)), table_name(std::move(table_name_p)) { types = std::move(types_p); InitializeChunk(); collection = make_uniq(allocator, GetActiveTypes()); + flush_memory_threshold = (flush_memory_threshold_p == DConstants::INVALID_INDEX) + ? optional_idx::Invalid() + : optional_idx(flush_memory_threshold_p); } QueryAppender::~QueryAppender() { @@ -599,9 +608,13 @@ void QueryAppender::FlushInternal(ColumnDataCollection &collection) { //===--------------------------------------------------------------------===// // Internal Appender //===--------------------------------------------------------------------===// -InternalAppender::InternalAppender(ClientContext &context_p, TableCatalogEntry &table_p, const idx_t flush_count_p) +InternalAppender::InternalAppender(ClientContext &context_p, TableCatalogEntry &table_p, const idx_t flush_count_p, + const idx_t flush_memory_threshold_p) : BaseAppender(Allocator::DefaultAllocator(), table_p.GetTypes(), AppenderType::PHYSICAL, flush_count_p), context(context_p), table(table_p) { + flush_memory_threshold = (flush_memory_threshold_p == DConstants::INVALID_INDEX) + ? 
optional_idx::Invalid() + : optional_idx(flush_memory_threshold_p); } InternalAppender::~InternalAppender() { @@ -620,4 +633,38 @@ void BaseAppender::Close() { } } +void BaseAppender::Clear() { + chunk.Reset(); + + if (collection) { + collection->Reset(); + } + + column = 0; +} + +bool BaseAppender::ShouldFlushChunk() const { + if (chunk.size() >= STANDARD_VECTOR_SIZE) { + return true; + } + + if (!flush_memory_threshold.IsValid()) { + return false; + } + + return (collection->AllocationSize() >= flush_memory_threshold.GetIndex()); +} + +bool BaseAppender::ShouldFlush() const { + if (collection->Count() >= flush_count) { + return true; + } + + if (!flush_memory_threshold.IsValid()) { + return false; + } + + return (collection->AllocationSize() >= flush_memory_threshold.GetIndex()); +} + } // namespace duckdb diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp index e98070c18..a4ef63035 100644 --- a/src/duckdb/src/main/attached_database.cpp +++ b/src/duckdb/src/main/attached_database.cpp @@ -11,17 +11,23 @@ #include "duckdb/transaction/duck_transaction_manager.hpp" #include "duckdb/main/database_path_and_type.hpp" #include "duckdb/main/valid_checker.hpp" +#include "duckdb/storage/block_allocator.hpp" namespace duckdb { -StoredDatabasePath::StoredDatabasePath(DatabaseFilePathManager &manager, string path_p, const string &name) - : manager(manager), path(std::move(path_p)) { +StoredDatabasePath::StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path_p, + const string &name) + : db_manager(db_manager), manager(manager), path(std::move(path_p)) { } StoredDatabasePath::~StoredDatabasePath() { manager.EraseDatabasePath(path); } +void StoredDatabasePath::OnDetach() { + manager.DetachDatabase(db_manager, path); +} + //===--------------------------------------------------------------------===// // Attach Options //===--------------------------------------------------------------------===// @@ -31,11 +37,9 @@ AttachOptions::AttachOptions(const DBConfigOptions &options) AttachOptions::AttachOptions(const unordered_map &attach_options, const AccessMode default_access_mode) : access_mode(default_access_mode) { - for (auto &entry : attach_options) { if (entry.first == "readonly" || entry.first == "read_only") { // Extract the read access mode. - auto read_only = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); if (read_only) { access_mode = AccessMode::READ_ONLY; @@ -45,6 +49,13 @@ AttachOptions::AttachOptions(const unordered_map &attach_options, continue; } + if (entry.first == "recovery_mode") { + // Extract the recovery mode. + auto mode_str = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); + recovery_mode = EnumUtil::FromString(mode_str); + continue; + } + if (entry.first == "readwrite" || entry.first == "read_write") { // Extract the write access mode. auto read_write = BooleanValue::Get(entry.second.DefaultCastAs(LogicalType::BOOLEAN)); @@ -77,7 +88,6 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType ty : CatalogEntry(CatalogType::DATABASE_ENTRY, type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG, 0), db(db), type(type) { - // This database does not have storage, or uses temporary_objects for in-memory storage. 
D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); if (type == AttachedDatabaseType::TEMP_DATABASE) { @@ -99,7 +109,9 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, str } else { type = AttachedDatabaseType::READ_WRITE_DATABASE; } + recovery_mode = options.recovery_mode; visibility = options.visibility; + // We create the storage after the catalog to guarantee we allow extensions to instantiate the DuckCatalog. catalog = make_uniq(*this); stored_database_path = std::move(options.stored_database_path); @@ -117,6 +129,7 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, Sto } else { type = AttachedDatabaseType::READ_WRITE_DATABASE; } + recovery_mode = options.recovery_mode; visibility = options.visibility; optional_ptr storage_info = storage_extension->storage_info.get(); @@ -157,6 +170,13 @@ bool AttachedDatabase::NameIsReserved(const string &name) { return name == DEFAULT_SCHEMA || name == TEMP_CATALOG || name == SYSTEM_CATALOG; } +string AttachedDatabase::StoredPath() const { + if (stored_database_path) { + return stored_database_path->path; + } + return string(); +} + static string RemoveQueryParams(const string &name) { auto vec = StringUtil::Split(name, "?"); D_ASSERT(!vec.empty()); @@ -181,7 +201,7 @@ void AttachedDatabase::Initialize(optional_ptr context) { catalog->Initialize(context, false); } if (storage) { - storage->Initialize(QueryContext(context)); + storage->Initialize(context); } } @@ -232,6 +252,9 @@ void AttachedDatabase::OnDetach(ClientContext &context) { if (catalog) { catalog->OnDetach(context); } + if (stored_database_path && visibility != AttachVisibility::HIDDEN) { + stored_database_path->OnDetach(); + } } void AttachedDatabase::Close() { @@ -252,6 +275,12 @@ void AttachedDatabase::Close() { options.wal_action = CheckpointWALAction::DELETE_WAL; storage->CreateCheckpoint(QueryContext(), options); } + } catch (std::exception &ex) { + ErrorData data(ex); + try { + DUCKDB_LOG_ERROR(db, "AttachedDatabase::Close()\t\t" + data.Message()); + } catch (...) { // NOLINT + } } catch (...) 
{ // NOLINT } } @@ -266,10 +295,6 @@ void AttachedDatabase::Close() { catalog.reset(); storage.reset(); stored_database_path.reset(); - - if (Allocator::SupportsFlush()) { - Allocator::FlushAll(); - } } } // namespace duckdb diff --git a/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp b/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp index 3c593374c..e9f949098 100644 --- a/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/batched_buffered_data.cpp @@ -14,9 +14,8 @@ void BatchedBufferedData::BlockSink(const InterruptState &blocked_sink, idx_t ba blocked_sinks.emplace(batch, blocked_sink); } -BatchedBufferedData::BatchedBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::BATCHED, std::move(context)), buffer_byte_count(0), read_queue_byte_count(0), - min_batch(0) { +BatchedBufferedData::BatchedBufferedData(ClientContext &context) + : BufferedData(BufferedData::Type::BATCHED, context), buffer_byte_count(0), read_queue_byte_count(0), min_batch(0) { read_queue_capacity = (idx_t)(static_cast(total_buffer_size) * 0.6); buffer_capacity = (idx_t)(static_cast(total_buffer_size) * 0.4); } diff --git a/src/duckdb/src/main/buffered_data/buffered_data.cpp b/src/duckdb/src/main/buffered_data/buffered_data.cpp index 156539815..0e01df8dc 100644 --- a/src/duckdb/src/main/buffered_data/buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/buffered_data.cpp @@ -4,9 +4,8 @@ namespace duckdb { -BufferedData::BufferedData(Type type, weak_ptr context_p) : type(type), context(std::move(context_p)) { - auto client_context = context.lock(); - auto &config = ClientConfig::GetConfig(*client_context); +BufferedData::BufferedData(Type type, ClientContext &context_p) : type(type), context(context_p.shared_from_this()) { + auto &config = ClientConfig::GetConfig(context_p); total_buffer_size = config.streaming_buffer_size; } diff --git a/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp b/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp index 4b6a3a534..59cde1f43 100644 --- a/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp +++ b/src/duckdb/src/main/buffered_data/simple_buffered_data.cpp @@ -6,8 +6,7 @@ namespace duckdb { -SimpleBufferedData::SimpleBufferedData(weak_ptr context) - : BufferedData(BufferedData::Type::SIMPLE, std::move(context)) { +SimpleBufferedData::SimpleBufferedData(ClientContext &context) : BufferedData(BufferedData::Type::SIMPLE, context) { buffered_count = 0; buffer_size = total_buffer_size; } diff --git a/src/duckdb/src/main/capi/aggregate_function-c.cpp b/src/duckdb/src/main/capi/aggregate_function-c.cpp index 4eb461123..c5c5ddb26 100644 --- a/src/duckdb/src/main/capi/aggregate_function-c.cpp +++ b/src/duckdb/src/main/capi/aggregate_function-c.cpp @@ -193,7 +193,7 @@ void duckdb_aggregate_function_set_return_type(duckdb_aggregate_function functio } auto &aggregate_function = GetCAggregateFunction(function); auto logical_type = reinterpret_cast(type); - aggregate_function.return_type = *logical_type; + aggregate_function.SetReturnType(*logical_type); } void duckdb_aggregate_function_set_functions(duckdb_aggregate_function function, duckdb_aggregate_state_size state_size, @@ -218,7 +218,7 @@ void duckdb_aggregate_function_set_destructor(duckdb_aggregate_function function auto &aggregate_function = GetCAggregateFunction(function); auto &function_info = aggregate_function.function_info->Cast(); function_info.destroy = destroy; - aggregate_function.destructor = 
duckdb::CAPIAggregateDestructor; + aggregate_function.SetStateDestructorCallback(duckdb::CAPIAggregateDestructor); } duckdb_state duckdb_register_aggregate_function(duckdb_connection connection, duckdb_aggregate_function function) { @@ -237,7 +237,7 @@ void duckdb_aggregate_function_set_special_handling(duckdb_aggregate_function fu return; } auto &aggregate_function = GetCAggregateFunction(function); - aggregate_function.null_handling = duckdb::FunctionNullHandling::SPECIAL_HANDLING; + aggregate_function.SetNullHandling(duckdb::FunctionNullHandling::SPECIAL_HANDLING); } void duckdb_aggregate_function_set_extra_info(duckdb_aggregate_function function, void *extra_info, @@ -311,8 +311,8 @@ duckdb_state duckdb_register_aggregate_function_set(duckdb_connection connection if (aggregate_function.name.empty() || !info.update || !info.combine || !info.finalize) { return DuckDBError; } - if (duckdb::TypeVisitor::Contains(aggregate_function.return_type, duckdb::LogicalTypeId::INVALID) || - duckdb::TypeVisitor::Contains(aggregate_function.return_type, duckdb::LogicalTypeId::ANY)) { + if (duckdb::TypeVisitor::Contains(aggregate_function.GetReturnType(), duckdb::LogicalTypeId::INVALID) || + duckdb::TypeVisitor::Contains(aggregate_function.GetReturnType(), duckdb::LogicalTypeId::ANY)) { return DuckDBError; } for (const auto &argument : aggregate_function.arguments) { diff --git a/src/duckdb/src/main/capi/appender-c.cpp b/src/duckdb/src/main/capi/appender-c.cpp index a54536b20..959db5098 100644 --- a/src/duckdb/src/main/capi/appender-c.cpp +++ b/src/duckdb/src/main/capi/appender-c.cpp @@ -318,6 +318,10 @@ duckdb_state duckdb_appender_flush(duckdb_appender appender_p) { return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Flush(); }); } +duckdb_state duckdb_appender_clear(duckdb_appender appender_p) { + return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Clear(); }); +} + duckdb_state duckdb_appender_close(duckdb_appender appender_p) { return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Close(); }); } diff --git a/src/duckdb/src/main/capi/arrow-c.cpp b/src/duckdb/src/main/capi/arrow-c.cpp index a1bc5391f..f7865bb3f 100644 --- a/src/duckdb/src/main/capi/arrow-c.cpp +++ b/src/duckdb/src/main/capi/arrow-c.cpp @@ -18,10 +18,14 @@ using duckdb::QueryResultType; duckdb_error_data duckdb_to_arrow_schema(duckdb_arrow_options arrow_options, duckdb_logical_type *types, const char **names, idx_t column_count, struct ArrowSchema *out_schema) { - - if (!types || !names || !arrow_options || !out_schema) { + if (!arrow_options || !out_schema) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Invalid argument(s) to duckdb_to_arrow_schema"); + } + // types and names can be nullptr when column_count is 0 + if (column_count > 0 && (!types || !names)) { return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Invalid argument(s) to duckdb_to_arrow_schema"); } + duckdb::vector schema_types; duckdb::vector schema_names; for (idx_t i = 0; i < column_count; i++) { @@ -298,7 +302,6 @@ void duckdb_destroy_arrow(duckdb_arrow *result) { } void duckdb_destroy_arrow_stream(duckdb_arrow_stream *stream_p) { - auto stream = reinterpret_cast(*stream_p); if (!stream) { return; diff --git a/src/duckdb/src/main/capi/cast/from_decimal-c.cpp b/src/duckdb/src/main/capi/cast/from_decimal-c.cpp index e6bc6f98c..a18d0a460 100644 --- a/src/duckdb/src/main/capi/cast/from_decimal-c.cpp +++ 
b/src/duckdb/src/main/capi/cast/from_decimal-c.cpp @@ -13,23 +13,22 @@ bool CastDecimalCInternal(duckdb_result *source, duckdb_string &result, idx_t co auto scale = duckdb::DecimalType::GetScale(source_type); duckdb::Vector result_vec(duckdb::LogicalType::VARCHAR, false, false); duckdb::string_t result_string; - void *source_address = UnsafeFetchPtr(source, col, row); + auto source_value = UnsafeFetch(source, col, row); switch (source_type.InternalType()) { case duckdb::PhysicalType::INT16: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); + result_string = duckdb::StringCastFromDecimal::Operation(static_cast(source_value), width, + scale, result_vec); break; case duckdb::PhysicalType::INT32: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); + result_string = duckdb::StringCastFromDecimal::Operation(static_cast(source_value), width, + scale, result_vec); break; case duckdb::PhysicalType::INT64: - result_string = duckdb::StringCastFromDecimal::Operation(UnsafeFetchFromPtr(source_address), - width, scale, result_vec); + result_string = duckdb::StringCastFromDecimal::Operation(static_cast(source_value), width, + scale, result_vec); break; case duckdb::PhysicalType::INT128: - result_string = duckdb::StringCastFromDecimal::Operation( - UnsafeFetchFromPtr(source_address), width, scale, result_vec); + result_string = duckdb::StringCastFromDecimal::Operation(source_value, width, scale, result_vec); break; default: throw duckdb::InternalException("Unimplemented internal type for decimal"); @@ -48,10 +47,11 @@ duckdb_hugeint FetchInternals(void *source_address) { template <> duckdb_hugeint FetchInternals(void *source_address) { + const int16_t source_value = static_cast(UnsafeFetchFromPtr(source_address)); duckdb_hugeint result; int16_t intermediate_result; - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + if (!TryCast::Operation(source_value, intermediate_result)) { intermediate_result = FetchDefaultValue::Operation(); } hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); @@ -61,10 +61,11 @@ duckdb_hugeint FetchInternals(void *source_address) { } template <> duckdb_hugeint FetchInternals(void *source_address) { + const int32_t source_value = static_cast(UnsafeFetchFromPtr(source_address)); duckdb_hugeint result; int32_t intermediate_result; - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + if (!TryCast::Operation(source_value, intermediate_result)) { intermediate_result = FetchDefaultValue::Operation(); } hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); @@ -74,10 +75,11 @@ duckdb_hugeint FetchInternals(void *source_address) { } template <> duckdb_hugeint FetchInternals(void *source_address) { + const int64_t source_value = UnsafeFetchFromPtr(source_address); duckdb_hugeint result; int64_t intermediate_result; - if (!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + if (!TryCast::Operation(source_value, intermediate_result)) { intermediate_result = FetchDefaultValue::Operation(); } hugeint_t hugeint_result = Hugeint::Cast(intermediate_result); @@ -87,10 +89,11 @@ duckdb_hugeint FetchInternals(void *source_address) { } template <> duckdb_hugeint FetchInternals(void *source_address) { + const hugeint_t source_value = UnsafeFetchFromPtr(source_address); duckdb_hugeint result; hugeint_t intermediate_result; - if 
(!TryCast::Operation(UnsafeFetchFromPtr(source_address), intermediate_result)) { + if (!TryCast::Operation(source_value, intermediate_result)) { intermediate_result = FetchDefaultValue::Operation(); } result.lower = intermediate_result.lower; diff --git a/src/duckdb/src/main/capi/cast_function-c.cpp b/src/duckdb/src/main/capi/cast_function-c.cpp index 39a5d90a7..a0b5f243e 100644 --- a/src/duckdb/src/main/capi/cast_function-c.cpp +++ b/src/duckdb/src/main/capi/cast_function-c.cpp @@ -25,7 +25,6 @@ struct CCastFunction { }; struct CCastFunctionUserData { - duckdb_function_info data_ptr = nullptr; duckdb_delete_callback_t delete_callback = nullptr; @@ -56,7 +55,6 @@ struct CCastFunctionData final : public BoundCastData { }; static bool CAPICastFunction(Vector &input, Vector &output, idx_t count, CastParameters ¶meters) { - const auto is_const = input.GetVectorType() == VectorType::CONSTANT_VECTOR; input.Flatten(count); diff --git a/src/duckdb/src/main/capi/catalog-c.cpp b/src/duckdb/src/main/capi/catalog-c.cpp new file mode 100644 index 000000000..59421e909 --- /dev/null +++ b/src/duckdb/src/main/capi/catalog-c.cpp @@ -0,0 +1,166 @@ +#include "duckdb/common/type_visitor.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry.hpp" + +namespace duckdb { +namespace { + +struct CCatalogWrapper { + Catalog &catalog; + string catalog_type; + CCatalogWrapper(Catalog &catalog, const string &catalog_type) : catalog(catalog), catalog_type(catalog_type) { + } +}; + +struct CCatalogEntryWrapper { + CatalogEntry &entry; + CCatalogEntryWrapper(CatalogEntry &entry) : entry(entry) { + } +}; + +CatalogType CatalogTypeFromC(duckdb_catalog_entry_type type) { + switch (type) { + case DUCKDB_CATALOG_ENTRY_TYPE_TABLE: + return CatalogType::TABLE_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_SCHEMA: + return CatalogType::SCHEMA_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_VIEW: + return CatalogType::VIEW_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_INDEX: + return CatalogType::INDEX_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_PREPARED_STATEMENT: + return CatalogType::PREPARED_STATEMENT; + case DUCKDB_CATALOG_ENTRY_TYPE_SEQUENCE: + return CatalogType::SEQUENCE_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_COLLATION: + return CatalogType::COLLATION_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_TYPE: + return CatalogType::TYPE_ENTRY; + case DUCKDB_CATALOG_ENTRY_TYPE_DATABASE: + return CatalogType::DATABASE_ENTRY; + default: + return CatalogType::INVALID; + } +} + +duckdb_catalog_entry_type CatalogTypeToC(CatalogType type) { + switch (type) { + case CatalogType::TABLE_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_TABLE; + case CatalogType::SCHEMA_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_SCHEMA; + case CatalogType::VIEW_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_VIEW; + case CatalogType::INDEX_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_INDEX; + case CatalogType::PREPARED_STATEMENT: + return DUCKDB_CATALOG_ENTRY_TYPE_PREPARED_STATEMENT; + case CatalogType::SEQUENCE_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_SEQUENCE; + case CatalogType::COLLATION_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_COLLATION; + case CatalogType::TYPE_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_TYPE; + case CatalogType::DATABASE_ENTRY: + return DUCKDB_CATALOG_ENTRY_TYPE_DATABASE; + default: + return DUCKDB_CATALOG_ENTRY_TYPE_INVALID; + } +} + +} // namespace +} // namespace duckdb + 
+//---------------------------------------------------------------------------------------------------------------------- +// Catalog +//---------------------------------------------------------------------------------------------------------------------- +duckdb_catalog duckdb_client_context_get_catalog(duckdb_client_context context, const char *name) { + if (!context || !name || strlen(name) == 0) { + return nullptr; + } + + auto &context_ref = *reinterpret_cast(context); + if (!context_ref.context.transaction.HasActiveTransaction()) { + return nullptr; + } + + auto catalog_ptr = duckdb::Catalog::GetCatalogEntry(context_ref.context, name); + + if (!catalog_ptr) { + return nullptr; + } + + auto &catalog_ref = *catalog_ptr; + auto catalog_wrapper = new duckdb::CCatalogWrapper(catalog_ref, catalog_ref.GetCatalogType()); + return reinterpret_cast(catalog_wrapper); +} + +void duckdb_destroy_catalog(duckdb_catalog *catalog) { + if (!catalog || !*catalog) { + return; + } + auto catalog_ptr = reinterpret_cast(*catalog); + delete catalog_ptr; + *catalog = nullptr; +} + +const char *duckdb_catalog_get_type_name(duckdb_catalog catalog) { + if (!catalog) { + return nullptr; + } + auto &catalog_ref = *reinterpret_cast(catalog); + return catalog_ref.catalog_type.c_str(); +} + +duckdb_catalog_entry duckdb_catalog_get_entry(duckdb_catalog catalog, duckdb_client_context context, + duckdb_catalog_entry_type entry_type, const char *schema_name, + const char *entry_name) { + if (!catalog || !context || !schema_name || !entry_name) { + return nullptr; + } + + auto &catalog_ref = *reinterpret_cast(catalog); + auto &context_ref = *reinterpret_cast(context); + + auto entry = catalog_ref.catalog.GetEntry(context_ref.context, duckdb::CatalogTypeFromC(entry_type), schema_name, + entry_name, duckdb::OnEntryNotFound::RETURN_NULL); + + if (!entry) { + return nullptr; + } + + return reinterpret_cast(new duckdb::CCatalogEntryWrapper(*entry)); +} + +//---------------------------------------------------------------------------------------------------------------------- +// Catalog Entry +//---------------------------------------------------------------------------------------------------------------------- + +duckdb_catalog_entry_type duckdb_catalog_entry_get_type(duckdb_catalog_entry entry) { + if (!entry) { + return DUCKDB_CATALOG_ENTRY_TYPE_INVALID; + } + + auto &entry_ref = *reinterpret_cast(entry); + return duckdb::CatalogTypeToC(entry_ref.entry.type); +} + +const char *duckdb_catalog_entry_get_name(duckdb_catalog_entry entry) { + if (!entry) { + return nullptr; + } + auto &entry_ref = *reinterpret_cast(entry); + return entry_ref.entry.name.c_str(); +} + +void duckdb_destroy_catalog_entry(duckdb_catalog_entry *entry) { + if (!entry || !*entry) { + return; + } + auto entry_ptr = reinterpret_cast(*entry); + delete entry_ptr; + *entry = nullptr; +} diff --git a/src/duckdb/src/main/capi/config_options-c.cpp b/src/duckdb/src/main/capi/config_options-c.cpp new file mode 100644 index 000000000..b895245fc --- /dev/null +++ b/src/duckdb/src/main/capi/config_options-c.cpp @@ -0,0 +1,159 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +namespace duckdb { +namespace { + +struct CConfigOption { + string name; + LogicalType type; + Value default_value; + SetScope default_scope = SetScope::SESSION; + string description; +}; + +} // namespace +} // namespace duckdb + +duckdb_config_option duckdb_create_config_option() { + auto coption = new duckdb::CConfigOption(); + return reinterpret_cast(coption); +} + +void 
duckdb_destroy_config_option(duckdb_config_option *option) { + if (!option || !*option) { + return; + } + auto coption = *reinterpret_cast(option); + delete coption; + + *option = nullptr; +} + +void duckdb_config_option_set_name(duckdb_config_option option, const char *name) { + if (!option || !name) { + return; + } + auto coption = reinterpret_cast(option); + coption->name = name; +} + +void duckdb_config_option_set_type(duckdb_config_option option, duckdb_logical_type type) { + if (!option || !type) { + return; + } + auto coption = reinterpret_cast(option); + coption->type = *reinterpret_cast(type); +} + +void duckdb_config_option_set_default_value(duckdb_config_option option, duckdb_value default_value) { + if (!option || !default_value) { + return; + } + auto coption = reinterpret_cast(option); + auto cvalue = reinterpret_cast(default_value); + + if (coption->type.id() == duckdb::LogicalTypeId::INVALID) { + coption->type = cvalue->type(); + coption->default_value = *cvalue; + return; + } + + if (coption->type != cvalue->type()) { + coption->default_value = cvalue->DefaultCastAs(coption->type, false); + return; + } + + coption->default_value = *cvalue; +} + +void duckdb_config_option_set_default_scope(duckdb_config_option option, duckdb_config_option_scope scope) { + if (!option) { + return; + } + auto coption = reinterpret_cast(option); + switch (scope) { + case DUCKDB_CONFIG_OPTION_SCOPE_LOCAL: + coption->default_scope = duckdb::SetScope::LOCAL; + break; + // Set the option for the current session/connection only. + case DUCKDB_CONFIG_OPTION_SCOPE_SESSION: + coption->default_scope = duckdb::SetScope::SESSION; + break; + // Set the option globally for all sessions/connections. + case DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL: + coption->default_scope = duckdb::SetScope::GLOBAL; + break; + default: + return; + } +} + +void duckdb_config_option_set_description(duckdb_config_option option, const char *description) { + if (!option || !description) { + return; + } + auto coption = reinterpret_cast(option); + coption->description = description; +} + +duckdb_state duckdb_register_config_option(duckdb_connection connection, duckdb_config_option option) { + if (!connection || !option) { + return DuckDBError; + } + + auto conn = reinterpret_cast(connection); + auto coption = reinterpret_cast(option); + + if (coption->name.empty() || coption->type.id() == duckdb::LogicalTypeId::INVALID) { + return DuckDBError; + } + + // TODO: This is not transactional... but theres no easy way to make it so currently. + try { + if (conn->context->db->config.HasExtensionOption(coption->name)) { + // Option already exists + return DuckDBError; + } + conn->context->db->config.AddExtensionOption(coption->name, coption->description, coption->type, + coption->default_value, nullptr, coption->default_scope); + } catch (...) { + return DuckDBError; + } + + return DuckDBSuccess; +} + +duckdb_value duckdb_client_context_get_config_option(duckdb_client_context context, const char *option_name, + duckdb_config_option_scope *out_scope) { + if (!context || !option_name) { + return nullptr; + } + + auto wrapper = reinterpret_cast(context); + auto &ctx = wrapper->context; + + duckdb_config_option_scope res_scope = DUCKDB_CONFIG_OPTION_SCOPE_INVALID; + duckdb::Value *res_value = nullptr; + + duckdb::Value result; + switch (ctx.TryGetCurrentSetting(option_name, result).GetScope()) { + case duckdb::SettingScope::LOCAL: + // This is a bit messy, but "session" is presented as LOCAL on the "settings" side of the API. 
+ res_value = new duckdb::Value(std::move(result)); + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_SESSION; + break; + case duckdb::SettingScope::GLOBAL: + res_value = new duckdb::Value(std::move(result)); + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_GLOBAL; + break; + default: + res_value = nullptr; + res_scope = DUCKDB_CONFIG_OPTION_SCOPE_INVALID; + break; + } + + if (out_scope) { + *out_scope = res_scope; + } + return reinterpret_cast(res_value); +} diff --git a/src/duckdb/src/main/capi/copy_function-c.cpp b/src/duckdb/src/main/capi/copy_function-c.cpp new file mode 100644 index 000000000..b1bb4394b --- /dev/null +++ b/src/duckdb/src/main/capi/copy_function-c.cpp @@ -0,0 +1,821 @@ +#include "duckdb/common/type_visitor.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/function/copy_function.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/capi/capi_internal_table.hpp" +#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" + +//---------------------------------------------------------------------------------------------------------------------- +// Common Copy Function Info +//---------------------------------------------------------------------------------------------------------------------- + +namespace duckdb { +namespace { + +struct CCopyFunctionInfo : public CopyFunctionInfo { + ~CCopyFunctionInfo() override { + if (extra_info && delete_callback) { + delete_callback(extra_info); + } + extra_info = nullptr; + delete_callback = nullptr; + } + + duckdb_copy_function_bind_t bind_to = nullptr; + duckdb_copy_function_global_init_t global_init = nullptr; + duckdb_copy_function_sink_t sink = nullptr; + duckdb_copy_function_finalize_t finalize = nullptr; + + void *extra_info = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +Value MakeValueFromCopyOptions(const case_insensitive_map_t> &options) { + child_list_t option_list; + for (auto &entry : options) { + // Uppercase the option name, to make it simpler for users + auto name = StringUtil::Upper(entry.first); + auto &values = entry.second; + + if (values.empty()) { + // Null! 
+ option_list.emplace_back(std::move(name), Value()); + continue; + } + if (values.size() == 1) { + // Single value + option_list.emplace_back(std::move(name), values[0]); + continue; + } + + auto is_same_type = true; + auto first_type = values[0].type(); + for (auto &val : values) { + if (val.type() != first_type) { + // Different types, cannot unify + is_same_type = false; + break; + } + } + + // Is same type: create a list of that type + if (is_same_type) { + option_list.emplace_back(std::move(name), Value::LIST(first_type, values)); + continue; + } + + // Different types: create an unnamed struct + child_list_t children; + for (auto &val : values) { + children.emplace_back("", val); + } + option_list.emplace_back(std::move(name), Value::STRUCT(children)); + } + + if (option_list.empty()) { + // No options + return Value(); + } + + // Return a struct of all options + return Value::STRUCT(std::move(option_list)); +} + +} // namespace +} // namespace duckdb + +duckdb_copy_function duckdb_create_copy_function() { + auto function = new duckdb::CopyFunction(""); + + function->function_info = duckdb::make_shared_ptr(); + + return reinterpret_cast(function); +} + +void duckdb_copy_function_set_name(duckdb_copy_function copy_function, const char *name) { + if (!copy_function || !name) { + return; + } + auto ©_function_ref = *reinterpret_cast(copy_function); + copy_function_ref.name = name; +} + +void duckdb_destroy_copy_function(duckdb_copy_function *copy_function) { + if (copy_function && *copy_function) { + auto function = reinterpret_cast(*copy_function); + delete function; + *copy_function = nullptr; + } +} + +void duckdb_copy_function_set_extra_info(duckdb_copy_function function, void *extra_info, + duckdb_delete_callback_t destroy) { + if (!function) { + return; + } + auto ©_function_ref = *reinterpret_cast(function); + auto &info = copy_function_ref.function_info->Cast(); + info.extra_info = extra_info; + info.delete_callback = destroy; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Bind +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { +struct CCopyToBindInfo : FunctionData { + shared_ptr function_info; + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; + + unique_ptr Copy() const override { + throw InternalException("CCopyToBindInfo cannot be copied"); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return bind_data == other.bind_data && delete_callback == other.delete_callback; + } + + ~CCopyToBindInfo() override { + if (bind_data && delete_callback) { + delete_callback(bind_data); + } + bind_data = nullptr; + delete_callback = nullptr; + } +}; + +struct CCopyFunctionToInternalBindInfo { + CCopyFunctionToInternalBindInfo(ClientContext &context, CopyFunctionBindInput &input, + const vector &sql_types, const vector &names, + const CCopyFunctionInfo &function_info) + : context(context), input(input), sql_types(sql_types), names(names), function_info(function_info), + success(true) { + } + + ClientContext &context; + CopyFunctionBindInput &input; + const vector &sql_types; + const vector &names; + const CCopyFunctionInfo &function_info; + bool success; + string error; + + // Supplied by the user + void *bind_data = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +unique_ptr 
CCopyToBind(ClientContext &context, CopyFunctionBindInput &input, const vector &names, + const vector &sql_types) { + auto &info = input.function_info->Cast(); + + auto result = make_uniq(); + result->function_info = input.function_info; + + if (info.bind_to) { + // Call the user-defined bind function + CCopyFunctionToInternalBindInfo bind_info(context, input, sql_types, names, info); + info.bind_to(reinterpret_cast(&bind_info)); + + // Pass on user bind data to the result + result->bind_data = bind_info.bind_data; + result->delete_callback = bind_info.delete_callback; + + if (!bind_info.success) { + throw BinderException(bind_info.error); + } + } + return std::move(result); +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_bind(duckdb_copy_function copy_function, duckdb_copy_function_bind_t bind) { + if (!copy_function || !bind) { + return; + } + + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C bind callback + info.bind_to = bind; +} + +void duckdb_copy_function_bind_set_error(duckdb_copy_function_bind_info info, const char *error) { + if (!info || !error) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_bind_get_extra_info(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.function_info.extra_info; +} + +duckdb_client_context duckdb_copy_function_bind_get_client_context(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +idx_t duckdb_copy_function_bind_get_column_count(duckdb_copy_function_bind_info info) { + if (!info) { + return 0; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.sql_types.size(); +} + +duckdb_logical_type duckdb_copy_function_bind_get_column_type(duckdb_copy_function_bind_info info, idx_t col_idx) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + if (col_idx >= info_ref.sql_types.size()) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(info_ref.sql_types[col_idx])); +} + +duckdb_value duckdb_copy_function_bind_get_options(duckdb_copy_function_bind_info info) { + if (!info) { + return nullptr; + } + + auto &info_ref = *reinterpret_cast(info); + auto &options = info_ref.input.info.options; + + // return as struct of options + auto options_value = duckdb::MakeValueFromCopyOptions(options); + return reinterpret_cast(new duckdb::Value(options_value)); +} + +void duckdb_copy_function_bind_set_bind_data(duckdb_copy_function_bind_info info, void *bind_data, + duckdb_delete_callback_t destructor) { + if (!info) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Store the bind data and destructor + info_ref.bind_data = bind_data; + info_ref.delete_callback = destructor; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Global Initialize +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +struct CCopyToGlobalState : GlobalFunctionData { + void *global_state = nullptr; + duckdb_delete_callback_t 
delete_callback = nullptr; + + ~CCopyToGlobalState() override { + if (global_state && delete_callback) { + delete_callback(global_state); + } + global_state = nullptr; + delete_callback = nullptr; + } +}; + +struct CCopyToGlobalInitInfo { + CCopyToGlobalInitInfo(ClientContext &context, FunctionData &bind_data, const string &file_path) + : context(context), bind_data(bind_data), file_path(file_path) { + } + + ClientContext &context; + FunctionData &bind_data; + const string &file_path; + + string error; + bool success = true; + + void *global_state = nullptr; + duckdb_delete_callback_t delete_callback = nullptr; +}; + +unique_ptr CCopyToGlobalInit(ClientContext &context, FunctionData &bind_data, + const string &file_path) { + auto &bind_info = bind_data.Cast(); + auto &function_info = bind_info.function_info->Cast(); + + auto result = make_uniq(); + + if (function_info.global_init) { + // Call the user-defined global init function + CCopyToGlobalInitInfo global_init_info(context, bind_data, file_path); + function_info.global_init(reinterpret_cast(&global_init_info)); + + // Pass on user global state to the result + result->global_state = global_init_info.global_state; + result->delete_callback = global_init_info.delete_callback; + + if (!global_init_info.success) { + throw InvalidInputException(global_init_info.error); + } + } + + return std::move(result); +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_global_init(duckdb_copy_function copy_function, duckdb_copy_function_global_init_t init) { + if (!copy_function || !init) { + return; + } + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C global init callback + info.global_init = init; +} + +void duckdb_copy_function_global_init_set_error(duckdb_copy_function_global_init_info info, const char *error) { + if (!info || !error) { + return; + } + auto &info_ref = *reinterpret_cast(info); + + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_global_init_get_extra_info(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.bind_data.Cast() + .function_info->Cast() + .extra_info; +} + +duckdb_client_context duckdb_copy_function_global_init_get_client_context(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +void *duckdb_copy_function_global_init_get_bind_data(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto &bind_info = info_ref.bind_data.Cast(); + + return bind_info.bind_data; +} + +void duckdb_copy_function_global_init_set_global_state(duckdb_copy_function_global_init_info info, void *global_state, + duckdb_delete_callback_t destructor) { + if (!info) { + return; + } + auto &info_ref = *reinterpret_cast(info); + info_ref.global_state = global_state; + info_ref.delete_callback = destructor; +} + +const char *duckdb_copy_function_global_init_get_file_path(duckdb_copy_function_global_init_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.file_path.c_str(); +} + 
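The bind and global-init plumbing above forwards user callbacks into DuckDB's internal COPY TO pipeline. A minimal sketch of how an extension might use this part of the new C API follows. The callback typedefs (`duckdb_copy_function_bind_t`, `duckdb_copy_function_global_init_t`) live in the C header and are assumed here to take only the corresponding info handle, as the internal call sites above suggest; `my_bind_data_t`, `my_global_state_t`, and the file handling are illustrative only.

#include "duckdb.h"
#include <stdio.h>
#include <stdlib.h>

typedef struct {
	idx_t column_count;
} my_bind_data_t;

typedef struct {
	FILE *file;
} my_global_state_t;

static void my_copy_bind(duckdb_copy_function_bind_info info) {
	// Remember how many columns the COPY TO statement writes.
	my_bind_data_t *bind_data = (my_bind_data_t *)malloc(sizeof(my_bind_data_t));
	bind_data->column_count = duckdb_copy_function_bind_get_column_count(info);
	if (bind_data->column_count == 0) {
		duckdb_copy_function_bind_set_error(info, "at least one column is required");
		free(bind_data);
		return;
	}
	// DuckDB calls the destructor when the bind data is destroyed.
	duckdb_copy_function_bind_set_bind_data(info, bind_data, free);
}

static void my_destroy_global_state(void *ptr) {
	my_global_state_t *state = (my_global_state_t *)ptr;
	if (state->file) {
		fclose(state->file);
	}
	free(state);
}

static void my_copy_global_init(duckdb_copy_function_global_init_info info) {
	// Open the target file once per COPY TO; the path comes from the COPY statement.
	my_global_state_t *state = (my_global_state_t *)malloc(sizeof(my_global_state_t));
	state->file = fopen(duckdb_copy_function_global_init_get_file_path(info), "w");
	if (!state->file) {
		duckdb_copy_function_global_init_set_error(info, "could not open output file");
		free(state);
		return;
	}
	duckdb_copy_function_global_init_set_global_state(info, state, my_destroy_global_state);
}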
+//---------------------------------------------------------------------------------------------------------------------- +// Copy To Local Initialize +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +unique_ptr CCopyToLocalInit(ExecutionContext &context, FunctionData &bind_data) { + // This isnt exposed to the C-API yet, so we just return empty local function data + return make_uniq(); +} + +} // namespace +} // namespace duckdb +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Sink +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +struct CCopyToSinkInfo { + CCopyToSinkInfo(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) + : context(context), bind_data(bind_data), gstate(gstate) { + } + + ClientContext &context; + FunctionData &bind_data; + GlobalFunctionData &gstate; + string error; + bool success = true; +}; + +void CCopyToSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate, DataChunk &input) { + auto &bind_info = bind_data.Cast(); + auto &function_info = bind_info.function_info->Cast(); + + // Flatten input (we dont support compressed execution yet!) + // TODO: Dont flatten! + input.Flatten(); + + CCopyToSinkInfo copy_to_sink_info(context.client, bind_data, gstate); + + // Sink is required! + function_info.sink(reinterpret_cast(©_to_sink_info), + reinterpret_cast(&input)); + + if (!copy_to_sink_info.success) { + throw InvalidInputException(copy_to_sink_info.error); + } +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_sink(duckdb_copy_function copy_function, duckdb_copy_function_sink_t function) { + if (!copy_function || !function) { + return; + } + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C sink callback + info.sink = function; +} + +void duckdb_copy_function_sink_set_error(duckdb_copy_function_sink_info info, const char *error) { + if (!info || !error) { + return; + } + auto &info_ref = *reinterpret_cast(info); + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_sink_get_extra_info(duckdb_copy_function_sink_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + return info_ref.bind_data.Cast() + .function_info->Cast() + .extra_info; +} + +duckdb_client_context duckdb_copy_function_sink_get_client_context(duckdb_copy_function_sink_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +void *duckdb_copy_function_sink_get_bind_data(duckdb_copy_function_sink_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto &bind_info = info_ref.bind_data.Cast(); + + return bind_info.bind_data; +} + +void *duckdb_copy_function_sink_get_global_state(duckdb_copy_function_sink_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto &gstate = info_ref.gstate.Cast(); + + return gstate.global_state; +} + 
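Continuing the sketch above: the sink is the only mandatory COPY TO callback (CCopyToSink invokes it unconditionally, and registration further down in this patch only enables the copy-to path when a sink is present). The sink typedef is assumed to receive the sink info plus the flattened data chunk, matching the internal call site; `duckdb_data_chunk_get_size` is the existing C API. After registration, `COPY tbl TO 'out.txt' (FORMAT my_format)` should route through these callbacks.

static void my_copy_sink(duckdb_copy_function_sink_info info, duckdb_data_chunk chunk) {
	my_global_state_t *state = (my_global_state_t *)duckdb_copy_function_sink_get_global_state(info);
	my_bind_data_t *bind_data = (my_bind_data_t *)duckdb_copy_function_sink_get_bind_data(info);
	idx_t rows = duckdb_data_chunk_get_size(chunk);
	// A real implementation would serialize the chunk's column data here.
	fprintf(state->file, "chunk with %llu rows and %llu columns\n", (unsigned long long)rows,
	        (unsigned long long)bind_data->column_count);
}

static void my_copy_finalize(duckdb_copy_function_finalize_info info) {
	// Finalize is optional; flush the output once all chunks have been sunk.
	my_global_state_t *state = (my_global_state_t *)duckdb_copy_function_finalize_get_global_state(info);
	fflush(state->file);
}

static duckdb_state register_my_copy_function(duckdb_connection con) {
	duckdb_copy_function function = duckdb_create_copy_function();
	duckdb_copy_function_set_name(function, "my_format");
	duckdb_copy_function_set_bind(function, my_copy_bind);
	duckdb_copy_function_set_global_init(function, my_copy_global_init);
	duckdb_copy_function_set_sink(function, my_copy_sink);
	duckdb_copy_function_set_finalize(function, my_copy_finalize);
	// duckdb_register_copy_function is added later in this patch; the catalog keeps its own copy,
	// so the handle can be destroyed afterwards.
	duckdb_state state = duckdb_register_copy_function(con, function);
	duckdb_destroy_copy_function(&function);
	return state;
}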
+//---------------------------------------------------------------------------------------------------------------------- +// Copy To Combine +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +void CCopyToCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, + LocalFunctionData &lstate) { + // Do nothing for now (this isnt exposed to the C-API yet) +} + +} // namespace +} // namespace duckdb + +//---------------------------------------------------------------------------------------------------------------------- +// Copy To Finalize +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +struct CCopyToFinalizeInfo { + CCopyToFinalizeInfo(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) + : context(context), bind_data(bind_data), gstate(gstate) { + } + + ClientContext &context; + FunctionData &bind_data; + GlobalFunctionData &gstate; + + string error; + bool success = true; +}; + +void CCopyToFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) { + auto &bind_info = bind_data.Cast(); + auto &function_info = bind_info.function_info->Cast(); + + // Finalize is optional + if (function_info.finalize) { + CCopyToFinalizeInfo copy_to_finalize_info(context, bind_data, gstate); + function_info.finalize(reinterpret_cast(©_to_finalize_info)); + + if (!copy_to_finalize_info.success) { + throw InvalidInputException(copy_to_finalize_info.error); + } + } +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_finalize(duckdb_copy_function copy_function, duckdb_copy_function_finalize_t finalize) { + if (!copy_function || !finalize) { + return; + } + + auto ©_function_ref = *reinterpret_cast(copy_function); + auto &info = copy_function_ref.function_info->Cast(); + + // Set C finalize callback + info.finalize = finalize; +} + +void duckdb_copy_function_finalize_set_error(duckdb_copy_function_finalize_info info, const char *error) { + if (!info || !error) { + return; + } + + auto &info_ref = *reinterpret_cast(info); + // Set the error message + info_ref.error = error; + info_ref.success = false; +} + +void *duckdb_copy_function_finalize_get_extra_info(duckdb_copy_function_finalize_info info) { + if (!info) { + return nullptr; + } + + auto &info_ref = *reinterpret_cast(info); + return info_ref.bind_data.Cast() + .function_info->Cast() + .extra_info; +} + +duckdb_client_context duckdb_copy_function_finalize_get_client_context(duckdb_copy_function_finalize_info info) { + if (!info) { + return nullptr; + } + auto &info_ref = *reinterpret_cast(info); + auto wrapper = new duckdb::CClientContextWrapper(info_ref.context); + return reinterpret_cast(wrapper); +} + +void *duckdb_copy_function_finalize_get_bind_data(duckdb_copy_function_finalize_info info) { + if (!info) { + return nullptr; + } + + auto &info_ref = *reinterpret_cast(info); + auto &bind_info = info_ref.bind_data.Cast(); + return bind_info.bind_data; +} + +void *duckdb_copy_function_finalize_get_global_state(duckdb_copy_function_finalize_info info) { + if (!info) { + return nullptr; + } + + auto &info_ref = *reinterpret_cast(info); + auto &gstate = info_ref.gstate.Cast(); + return gstate.global_state; +} + +//---------------------------------------------------------------------------------------------------------------------- +// Copy 
FROM +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { +namespace { + +unique_ptr CCopyFromBind(ClientContext &context, CopyFromFunctionBindInput &info, + vector &expected_names, vector &expected_types) { + auto &tf_info = info.tf.function_info->Cast(); + auto result = make_uniq(tf_info); + + named_parameter_map_t named_parameters; + + // Turn all options into named parameters + for (auto opt : info.info.options) { + auto param_it = info.tf.named_parameters.find(opt.first); + if (param_it == info.tf.named_parameters.end()) { + // Option not found in the table function's named parameters + throw BinderException("'%s' is not a supported option for copy function '%s'", opt.first.c_str(), + info.tf.name.c_str()); + } + + // Try to convert a list of values into a single Value, either by extracting or unifying into a list + Value param_value; + if (opt.second.empty()) { + continue; + } + if (opt.second.size() == 1) { + param_value = opt.second[0]; + } else { + auto first_type = opt.second[0].type(); + auto is_same_type = true; + for (auto &val : opt.second) { + if (val.type() != first_type) { + is_same_type = false; + break; + } + } + if (is_same_type) { + param_value = Value::LIST(first_type, opt.second); + } else { + throw BinderException("Cannot pass multiple values of different types for copy option '%s'", + opt.first.c_str()); + } + } + + // Assing the option as a named parameter + named_parameters[opt.first] = param_value; + } + + // Also pass file path as a regular parameter + vector parameters; + parameters.push_back(Value(info.info.file_path)); + + // Now bind, using the normal table function bind mechanism + CTableInternalBindInfo bind_info(context, parameters, named_parameters, expected_types, expected_names, *result, + tf_info); + tf_info.bind(reinterpret_cast(&bind_info)); + if (!bind_info.success) { + throw BinderException(bind_info.error); + } + + return std::move(result); +} + +} // namespace +} // namespace duckdb + +void duckdb_copy_function_set_copy_from_function(duckdb_copy_function copy_function, + duckdb_table_function table_function) { + auto ©_function_ref = *reinterpret_cast(copy_function); + if (!copy_function || !table_function) { + return; + } + auto &tf = *reinterpret_cast(table_function); + auto &tf_info = tf.function_info->Cast(); + + if (tf.name.empty()) { + // Take the name from the copy function if not set + tf.name = copy_function_ref.name; + } + + if (!tf_info.bind || !tf_info.init || !tf_info.function) { + return; + } + for (auto it = tf.named_parameters.begin(); it != tf.named_parameters.end(); it++) { + if (duckdb::TypeVisitor::Contains(it->second, duckdb::LogicalTypeId::INVALID)) { + return; + } + } + for (const auto &argument : tf.arguments) { + if (duckdb::TypeVisitor::Contains(argument, duckdb::LogicalTypeId::INVALID)) { + return; + } + } + + // Set the bind callback to mark this as a "copy from" capable function + copy_function_ref.copy_from_bind = duckdb::CCopyFromBind; + copy_function_ref.copy_from_function = tf; +} + +idx_t duckdb_table_function_bind_get_result_column_count(duckdb_bind_info bind_info) { + if (!bind_info) { + return 0; + } + auto &bind_info_ref = *reinterpret_cast(bind_info); + return bind_info_ref.return_types.size(); +} + +duckdb_logical_type duckdb_table_function_bind_get_result_column_type(duckdb_bind_info bind_info, idx_t col_idx) { + if (!bind_info) { + return nullptr; + } + auto &bind_info_ref = *reinterpret_cast(bind_info); + if 
(col_idx >= bind_info_ref.return_types.size()) { + return nullptr; + } + return reinterpret_cast(new duckdb::LogicalType(bind_info_ref.return_types[col_idx])); +} + +const char *duckdb_table_function_bind_get_result_column_name(duckdb_bind_info bind_info, idx_t col_idx) { + if (!bind_info) { + return nullptr; + } + auto &bind_info_ref = *reinterpret_cast(bind_info); + if (col_idx >= bind_info_ref.names.size()) { + return nullptr; + } + return bind_info_ref.names[col_idx].c_str(); +} + +//---------------------------------------------------------------------------------------------------------------------- +// Register +//---------------------------------------------------------------------------------------------------------------------- +duckdb_state duckdb_register_copy_function(duckdb_connection connection, duckdb_copy_function copy_function) { + if (!connection || !copy_function) { + return DuckDBError; + } + + auto ©_function_ref = *reinterpret_cast(copy_function); + + // Check that the copy function has a valid name + if (copy_function_ref.name.empty()) { + return DuckDBError; + } + + auto &info = copy_function_ref.function_info->Cast(); + + auto is_copy_to = false; + auto is_copy_from = copy_function_ref.copy_from_bind != nullptr; + + if (info.sink) { + // Set the copy function callbacks + is_copy_to = true; + copy_function_ref.copy_to_bind = duckdb::CCopyToBind; + copy_function_ref.copy_to_initialize_global = duckdb::CCopyToGlobalInit; + copy_function_ref.copy_to_initialize_local = duckdb::CCopyToLocalInit; + copy_function_ref.copy_to_sink = duckdb::CCopyToSink; + copy_function_ref.copy_to_combine = duckdb::CCopyToCombine; + copy_function_ref.copy_to_finalize = duckdb::CCopyToFinalize; + } + + if (!is_copy_to && !is_copy_from) { + // At least one of copy to or copy from must be implemented + return DuckDBError; + } + + auto &conn = *reinterpret_cast(connection); + try { + conn.context->RunFunctionInTransaction([&]() { + auto &catalog = duckdb::Catalog::GetSystemCatalog(*conn.context); + duckdb::CreateCopyFunctionInfo cp_info(copy_function_ref); + cp_info.on_conflict = duckdb::OnCreateConflict::ALTER_ON_CONFLICT; + catalog.CreateCopyFunction(*conn.context, cp_info); + }); + } catch (...) 
{ // LCOV_EXCL_START + return DuckDBError; + } // LCOV_EXCL_STOP + return DuckDBSuccess; +} diff --git a/src/duckdb/src/main/capi/data_chunk-c.cpp b/src/duckdb/src/main/capi/data_chunk-c.cpp index 7274852c4..77f6482ab 100644 --- a/src/duckdb/src/main/capi/data_chunk-c.cpp +++ b/src/duckdb/src/main/capi/data_chunk-c.cpp @@ -167,20 +167,20 @@ idx_t duckdb_list_vector_get_size(duckdb_vector vector) { duckdb_state duckdb_list_vector_set_size(duckdb_vector vector, idx_t size) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::SetListSize(*v, size); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_state duckdb_list_vector_reserve(duckdb_vector vector, idx_t required_capacity) { if (!vector) { - return duckdb_state::DuckDBError; + return DuckDBError; } auto v = reinterpret_cast(vector); duckdb::ListVector::Reserve(*v, required_capacity); - return duckdb_state::DuckDBSuccess; + return DuckDBSuccess; } duckdb_vector duckdb_struct_vector_get_child(duckdb_vector vector, idx_t index) { diff --git a/src/duckdb/src/main/capi/duckdb-c.cpp b/src/duckdb/src/main/capi/duckdb-c.cpp index 344fa265d..3cfedbf61 100644 --- a/src/duckdb/src/main/capi/duckdb-c.cpp +++ b/src/duckdb/src/main/capi/duckdb-c.cpp @@ -41,7 +41,7 @@ duckdb_state duckdb_open_internal(DBInstanceCacheWrapper *cache, const char *pat if (path) { path_str = path; } - wrapper->database = cache->instance_cache->GetOrCreateInstance(path_str, *db_config, true); + wrapper->database = cache->instance_cache->GetOrCreateInstance(path_str, *db_config); } else { wrapper->database = duckdb::make_shared_ptr(path, db_config); } diff --git a/src/duckdb/src/main/capi/file_system-c.cpp b/src/duckdb/src/main/capi/file_system-c.cpp index af82daa6c..e697c363d 100644 --- a/src/duckdb/src/main/capi/file_system-c.cpp +++ b/src/duckdb/src/main/capi/file_system-c.cpp @@ -3,7 +3,6 @@ namespace duckdb { namespace { struct CFileSystem { - FileSystem &fs; ErrorData error_data; diff --git a/src/duckdb/src/main/capi/logging-c.cpp b/src/duckdb/src/main/capi/logging-c.cpp new file mode 100644 index 000000000..47dc4c2f5 --- /dev/null +++ b/src/duckdb/src/main/capi/logging-c.cpp @@ -0,0 +1,134 @@ +#include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/logging/log_storage.hpp" + +namespace duckdb { + +class CallbackLogStorage : public LogStorage { +public: + CallbackLogStorage(const string &name, duckdb_logger_write_log_entry_t write_log_entry_fun, void *extra_data, + duckdb_delete_callback_t delete_callback) + : name(name), write_log_entry_fun(write_log_entry_fun), extra_data(extra_data), + delete_callback(delete_callback) { + } + + ~CallbackLogStorage() override { + if (!extra_data || !delete_callback) { + return; + } + delete_callback(extra_data); + } + + void WriteLogEntry(timestamp_t timestamp, LogLevel level, const string &log_type, const string &log_message, + const RegisteredLoggingContext &context) override { + if (write_log_entry_fun == nullptr) { + return; + } + auto c_timestamp = reinterpret_cast(×tamp); + write_log_entry_fun(extra_data, c_timestamp, EnumUtil::ToChars(level), log_type.c_str(), log_message.c_str()); + }; + + void WriteLogEntries(DataChunk &chunk, const RegisteredLoggingContext &context) override {}; + + void Flush(LoggingTargetTable table) override {}; + + void FlushAll() override {}; + + bool IsEnabled(LoggingTargetTable table) override { + return true; + } + + const string GetStorageName() override { + return name; + } + +private: + const 
string name;
+	duckdb_logger_write_log_entry_t write_log_entry_fun;
+	void *extra_data;
+	duckdb_delete_callback_t delete_callback;
+};
+
+struct LogStorageWrapper {
+	string name;
+	duckdb_logger_write_log_entry_t write_log_entry = nullptr;
+	void *extra_data = nullptr;
+	duckdb_delete_callback_t delete_callback = nullptr;
+};
+
+} // namespace duckdb
+
+using duckdb::DatabaseWrapper;
+using duckdb::LogStorageWrapper;
+
+duckdb_log_storage duckdb_create_log_storage() {
+	auto log_storage_wrapper = new LogStorageWrapper();
+	return reinterpret_cast<duckdb_log_storage>(log_storage_wrapper);
+}
+
+void duckdb_destroy_log_storage(duckdb_log_storage *log_storage) {
+	if (log_storage && *log_storage) {
+		auto log_storage_wrapper = reinterpret_cast<LogStorageWrapper *>(*log_storage);
+		if (log_storage_wrapper->extra_data && log_storage_wrapper->delete_callback) {
+			log_storage_wrapper->delete_callback(log_storage_wrapper->extra_data);
+		}
+		delete log_storage_wrapper;
+		*log_storage = nullptr;
+	}
+}
+
+void duckdb_log_storage_set_write_log_entry(duckdb_log_storage log_storage, duckdb_logger_write_log_entry_t function) {
+	if (!log_storage || !function) {
+		return;
+	}
+
+	auto log_storage_wrapper = reinterpret_cast<LogStorageWrapper *>(log_storage);
+	log_storage_wrapper->write_log_entry = function;
+}
+
+void duckdb_log_storage_set_extra_data(duckdb_log_storage log_storage, void *extra_data,
+                                       duckdb_delete_callback_t delete_callback) {
+	if (!log_storage) {
+		return;
+	}
+
+	auto log_storage_wrapper = reinterpret_cast<LogStorageWrapper *>(log_storage);
+	log_storage_wrapper->extra_data = extra_data;
+	log_storage_wrapper->delete_callback = delete_callback;
+}
+
+void duckdb_log_storage_set_name(duckdb_log_storage log_storage, const char *name) {
+	if (!log_storage || !name) {
+		return;
+	}
+	auto log_storage_wrapper = reinterpret_cast<LogStorageWrapper *>(log_storage);
+	log_storage_wrapper->name = name;
+}
+
+duckdb_state duckdb_register_log_storage(duckdb_database database, duckdb_log_storage log_storage) {
+	if (!database || !log_storage) {
+		return DuckDBError;
+	}
+
+	const auto db_wrapper = reinterpret_cast<DatabaseWrapper *>(database);
+	auto log_storage_wrapper = reinterpret_cast<LogStorageWrapper *>(log_storage);
+	if (log_storage_wrapper->name.empty() || log_storage_wrapper->write_log_entry == nullptr) {
+		return DuckDBError;
+	}
+
+	const auto &db = *db_wrapper->database;
+	auto shared_storage_ptr = duckdb::make_shared_ptr<duckdb::CallbackLogStorage>(
+	    log_storage_wrapper->name, log_storage_wrapper->write_log_entry, log_storage_wrapper->extra_data,
+	    log_storage_wrapper->delete_callback);
+	duckdb::shared_ptr<duckdb::LogStorage> storage_ptr = shared_storage_ptr;
+
+	const auto success = db.instance->GetLogManager().RegisterLogStorage(log_storage_wrapper->name, storage_ptr);
+	if (!success) {
+		return DuckDBError;
+	}
+
+	// Avoid leaking memory in case of registration failure because
+	// we transfer ownership when creating the shared pointer.
+ log_storage_wrapper->extra_data = nullptr; + log_storage_wrapper->delete_callback = nullptr; + return DuckDBSuccess; +} diff --git a/src/duckdb/src/main/capi/prepared-c.cpp b/src/duckdb/src/main/capi/prepared-c.cpp index 28b2f011f..ac5b638f8 100644 --- a/src/duckdb/src/main/capi/prepared-c.cpp +++ b/src/duckdb/src/main/capi/prepared-c.cpp @@ -88,7 +88,13 @@ duckdb_state duckdb_prepare(duckdb_connection connection, const char *query, const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement) { auto wrapper = reinterpret_cast(prepared_statement); - if (!wrapper || !wrapper->statement || !wrapper->statement->HasError()) { + if (!wrapper) { + return nullptr; + } + if (!wrapper->success) { + return wrapper->error_data.Message().c_str(); + } + if (!wrapper->statement || !wrapper->statement->HasError()) { return nullptr; } return wrapper->statement->error.Message().c_str(); @@ -191,7 +197,7 @@ const char *duckdb_prepared_statement_column_name(duckdb_prepared_statement prep } auto &names = wrapper->statement->GetNames(); - if (col_idx < 0 || col_idx >= names.size()) { + if (col_idx >= names.size()) { return nullptr; } return strdup(names[col_idx].c_str()); @@ -204,7 +210,7 @@ duckdb_logical_type duckdb_prepared_statement_column_logical_type(duckdb_prepare return nullptr; } auto types = wrapper->statement->GetTypes(); - if (col_idx < 0 || col_idx >= types.size()) { + if (col_idx >= types.size()) { return nullptr; } return reinterpret_cast(new LogicalType(types[col_idx])); @@ -229,9 +235,10 @@ duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx return DuckDBError; } if (param_idx <= 0 || param_idx > wrapper->statement->named_param_map.size()) { - wrapper->statement->error = + wrapper->error_data = duckdb::InvalidInputException("Can not bind to parameter number %d, statement only has %d parameter(s)", param_idx, wrapper->statement->named_param_map.size()); + wrapper->success = false; return DuckDBError; } auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); diff --git a/src/duckdb/src/main/capi/profiling_info-c.cpp b/src/duckdb/src/main/capi/profiling_info-c.cpp index 5dc06e22f..c8aa47f2f 100644 --- a/src/duckdb/src/main/capi/profiling_info-c.cpp +++ b/src/duckdb/src/main/capi/profiling_info-c.cpp @@ -3,7 +3,7 @@ using duckdb::Connection; using duckdb::DuckDB; using duckdb::EnumUtil; -using duckdb::MetricsType; +using duckdb::MetricType; using duckdb::optional_ptr; using duckdb::ProfilingNode; @@ -31,7 +31,7 @@ duckdb_value duckdb_profiling_info_get_value(duckdb_profiling_info info, const c } auto &node = *reinterpret_cast(info); auto &profiling_info = node.GetProfilingInfo(); - auto key_enum = EnumUtil::FromString(duckdb::StringUtil::Upper(key)); + auto key_enum = EnumUtil::FromString(duckdb::StringUtil::Upper(key)); if (!profiling_info.Enabled(profiling_info.settings, key_enum)) { return nullptr; } @@ -55,7 +55,7 @@ duckdb_value duckdb_profiling_info_get_metrics(duckdb_profiling_info info) { continue; } - if (key == EnumUtil::ToString(MetricsType::OPERATOR_TYPE)) { + if (key == EnumUtil::ToString(MetricType::OPERATOR_TYPE)) { auto type = duckdb::PhysicalOperatorType(metric.second.GetValue()); metrics_map[key] = EnumUtil::ToString(type); } else { diff --git a/src/duckdb/src/main/capi/scalar_function-c.cpp b/src/duckdb/src/main/capi/scalar_function-c.cpp index 7233b2c20..16c356bdb 100644 --- a/src/duckdb/src/main/capi/scalar_function-c.cpp +++ b/src/duckdb/src/main/capi/scalar_function-c.cpp @@ -9,7 +9,7 @@ #include 
"duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { - +namespace { struct CScalarFunctionInfo : public ScalarFunctionInfo { ~CScalarFunctionInfo() override { if (extra_info && delete_callback) { @@ -28,6 +28,7 @@ struct CScalarFunctionInfo : public ScalarFunctionInfo { struct CScalarFunctionBindData : public FunctionData { explicit CScalarFunctionBindData(CScalarFunctionInfo &info) : info(info) { } + ~CScalarFunctionBindData() override { if (bind_data && delete_callback) { delete_callback(bind_data); @@ -45,6 +46,7 @@ struct CScalarFunctionBindData : public FunctionData { } return std::move(copy); } + bool Equals(const FunctionData &other_p) const override { auto &other = other_p.Cast(); return info.extra_info == other.info.extra_info && info.function == other.info.function; @@ -117,7 +119,7 @@ duckdb_function_info ToCScalarFunctionInfo(duckdb::CScalarFunctionInternalFuncti unique_ptr CScalarFunctionBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto &info = bound_function.function_info->Cast(); + auto &info = bound_function.GetExtraFunctionInfo().Cast(); D_ASSERT(info.function); auto result = make_uniq(info); @@ -148,11 +150,12 @@ void CAPIScalarFunction(DataChunk &input, ExpressionState &state, Vector &result if (!function_info.success) { throw InvalidInputException(function_info.error); } - if (all_const && (input.size() == 1 || function.function.stability != FunctionStability::VOLATILE)) { + if (all_const && (input.size() == 1 || function.function.GetStability() != FunctionStability::VOLATILE)) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } } +} // namespace } // namespace duckdb using duckdb::ExpressionWrapper; @@ -164,7 +167,7 @@ using duckdb::GetCScalarFunctionSet; duckdb_scalar_function duckdb_create_scalar_function() { auto function = new duckdb::ScalarFunction("", {}, duckdb::LogicalType::INVALID, duckdb::CAPIScalarFunction, duckdb::CScalarFunctionBind); - function->function_info = duckdb::make_shared_ptr(); + function->SetExtraFunctionInfo(); return reinterpret_cast(function); } @@ -198,7 +201,7 @@ void duckdb_scalar_function_set_special_handling(duckdb_scalar_function function return; } auto &scalar_function = GetCScalarFunction(function); - scalar_function.null_handling = duckdb::FunctionNullHandling::SPECIAL_HANDLING; + scalar_function.SetNullHandling(duckdb::FunctionNullHandling::SPECIAL_HANDLING); } void duckdb_scalar_function_set_volatile(duckdb_scalar_function function) { @@ -206,7 +209,7 @@ void duckdb_scalar_function_set_volatile(duckdb_scalar_function function) { return; } auto &scalar_function = GetCScalarFunction(function); - scalar_function.stability = duckdb::FunctionStability::VOLATILE; + scalar_function.SetVolatile(); } void duckdb_scalar_function_add_parameter(duckdb_scalar_function function, duckdb_logical_type type) { @@ -224,7 +227,7 @@ void duckdb_scalar_function_set_return_type(duckdb_scalar_function function, duc } auto &scalar_function = GetCScalarFunction(function); auto logical_type = reinterpret_cast(type); - scalar_function.return_type = *logical_type; + scalar_function.SetReturnType(*logical_type); } void *duckdb_scalar_function_get_extra_info(duckdb_function_info info) { @@ -302,7 +305,7 @@ void duckdb_scalar_function_set_extra_info(duckdb_scalar_function function, void return; } auto &scalar_function = GetCScalarFunction(function); - auto &info = scalar_function.function_info->Cast(); + auto &info = scalar_function.GetExtraFunctionInfo().Cast(); info.extra_info = 
reinterpret_cast(extra_info); info.delete_callback = destroy; } @@ -312,7 +315,7 @@ void duckdb_scalar_function_set_bind(duckdb_scalar_function scalar_function, duc return; } auto &sf = GetCScalarFunction(scalar_function); - auto &info = sf.function_info->Cast(); + auto &info = sf.GetExtraFunctionInfo().Cast(); info.bind = bind; } @@ -338,7 +341,7 @@ void duckdb_scalar_function_set_function(duckdb_scalar_function function, duckdb return; } auto &scalar_function = GetCScalarFunction(function); - auto &info = scalar_function.function_info->Cast(); + auto &info = scalar_function.GetExtraFunctionInfo().Cast(); info.function = execute_func; } @@ -385,13 +388,13 @@ duckdb_state duckdb_register_scalar_function_set(duckdb_connection connection, d auto &scalar_function_set = GetCScalarFunctionSet(set); for (idx_t idx = 0; idx < scalar_function_set.Size(); idx++) { auto &scalar_function = scalar_function_set.GetFunctionReferenceByOffset(idx); - auto &info = scalar_function.function_info->Cast(); + auto &info = scalar_function.GetExtraFunctionInfo().Cast(); if (scalar_function.name.empty() || !info.function) { return DuckDBError; } - if (duckdb::TypeVisitor::Contains(scalar_function.return_type, duckdb::LogicalTypeId::INVALID) || - duckdb::TypeVisitor::Contains(scalar_function.return_type, duckdb::LogicalTypeId::ANY)) { + if (duckdb::TypeVisitor::Contains(scalar_function.GetReturnType(), duckdb::LogicalTypeId::INVALID) || + duckdb::TypeVisitor::Contains(scalar_function.GetReturnType(), duckdb::LogicalTypeId::ANY)) { return DuckDBError; } for (const auto &argument : scalar_function.arguments) { diff --git a/src/duckdb/src/main/capi/table_description-c.cpp b/src/duckdb/src/main/capi/table_description-c.cpp index 26624bbfc..cfcd01c43 100644 --- a/src/duckdb/src/main/capi/table_description-c.cpp +++ b/src/duckdb/src/main/capi/table_description-c.cpp @@ -1,5 +1,5 @@ -#include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/main/capi/capi_internal.hpp" using duckdb::Connection; using duckdb::ErrorData; @@ -68,14 +68,14 @@ const char *duckdb_table_description_error(duckdb_table_description table) { return wrapper->error.c_str(); } -duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, idx_t index) { +duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, duckdb::optional_idx index) { if (!wrapper) { return DuckDBError; } auto &table = wrapper->description; - if (index >= table->columns.size()) { - wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", index, - table->columns.size()); + if (index.IsValid() && index.GetIndex() >= table->columns.size()) { + wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", + index.GetIndex(), table->columns.size()); return DuckDBError; } return DuckDBSuccess; @@ -97,6 +97,16 @@ duckdb_state duckdb_column_has_default(duckdb_table_description table_descriptio return DuckDBSuccess; } +idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, duckdb::optional_idx()) == DuckDBError) { + return 0; + } + + auto &table = wrapper->description; + return table->columns.size(); +} + char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index) { auto wrapper = reinterpret_cast(table_description); if (GetTableDescription(wrapper, index) == 
DuckDBError) { @@ -113,3 +123,16 @@ char *duckdb_table_description_get_column_name(duckdb_table_description table_de return result; } + +duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, idx_t index) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, index) == DuckDBError) { + return nullptr; + } + + auto &table = wrapper->description; + auto &column = table->columns[index]; + + auto logical_type = new duckdb::LogicalType(column.Type()); + return reinterpret_cast(logical_type); +} diff --git a/src/duckdb/src/main/capi/table_function-c.cpp b/src/duckdb/src/main/capi/table_function-c.cpp index deb382ebd..7a6ab6459 100644 --- a/src/duckdb/src/main/capi/table_function-c.cpp +++ b/src/duckdb/src/main/capi/table_function-c.cpp @@ -3,65 +3,16 @@ #include "duckdb/common/types.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/main/capi/capi_internal_table.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/storage/statistics/node_statistics.hpp" namespace duckdb { - +namespace { //===--------------------------------------------------------------------===// // Structures //===--------------------------------------------------------------------===// -struct CTableFunctionInfo : public TableFunctionInfo { - ~CTableFunctionInfo() override { - if (extra_info && delete_callback) { - delete_callback(extra_info); - } - extra_info = nullptr; - delete_callback = nullptr; - } - - duckdb_table_function_bind_t bind = nullptr; - duckdb_table_function_init_t init = nullptr; - duckdb_table_function_init_t local_init = nullptr; - duckdb_table_function_t function = nullptr; - void *extra_info = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; -}; - -struct CTableBindData : public TableFunctionData { - explicit CTableBindData(CTableFunctionInfo &info) : info(info) { - } - ~CTableBindData() override { - if (bind_data && delete_callback) { - delete_callback(bind_data); - } - bind_data = nullptr; - delete_callback = nullptr; - } - - CTableFunctionInfo &info; - void *bind_data = nullptr; - duckdb_delete_callback_t delete_callback = nullptr; - unique_ptr stats; -}; - -struct CTableInternalBindInfo { - CTableInternalBindInfo(ClientContext &context, TableFunctionBindInput &input, vector &return_types, - vector &names, CTableBindData &bind_data, CTableFunctionInfo &function_info) - : context(context), input(input), return_types(return_types), names(names), bind_data(bind_data), - function_info(function_info), success(true) { - } - - ClientContext &context; - TableFunctionBindInput &input; - vector &return_types; - vector &names; - CTableBindData &bind_data; - CTableFunctionInfo &function_info; - bool success; - string error; -}; struct CTableInitData { ~CTableInitData() { @@ -160,7 +111,7 @@ unique_ptr CTableFunctionBind(ClientContext &context, TableFunctio D_ASSERT(info.bind && info.function && info.init); auto result = make_uniq(info); - CTableInternalBindInfo bind_info(context, input, return_types, names, *result, info); + CTableInternalBindInfo bind_info(context, input.inputs, input.named_parameters, return_types, names, *result, info); info.bind(ToCTableFunctionBindInfo(bind_info)); if (!bind_info.success) { throw BinderException(bind_info.error); @@ -216,6 +167,7 @@ void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChun } } +} // namespace } // 
namespace duckdb //===--------------------------------------------------------------------===// @@ -398,7 +350,7 @@ idx_t duckdb_bind_get_parameter_count(duckdb_bind_info info) { return 0; } auto &bind_info = GetCTableFunctionBindInfo(info); - return bind_info.input.inputs.size(); + return bind_info.parameters.size(); } duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { @@ -406,7 +358,7 @@ duckdb_value duckdb_bind_get_parameter(duckdb_bind_info info, idx_t index) { return nullptr; } auto &bind_info = GetCTableFunctionBindInfo(info); - return reinterpret_cast(new duckdb::Value(bind_info.input.inputs[index])); + return reinterpret_cast(new duckdb::Value(bind_info.parameters[index])); } duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char *name) { @@ -414,8 +366,8 @@ duckdb_value duckdb_bind_get_named_parameter(duckdb_bind_info info, const char * return nullptr; } auto &bind_info = GetCTableFunctionBindInfo(info); - auto t = bind_info.input.named_parameters.find(name); - if (t == bind_info.input.named_parameters.end()) { + auto t = bind_info.named_parameters.find(name); + if (t == bind_info.named_parameters.end()) { return nullptr; } else { return reinterpret_cast(new duckdb::Value(t->second)); diff --git a/src/duckdb/src/main/client_config.cpp b/src/duckdb/src/main/client_config.cpp index 868c80730..e8e7d8d86 100644 --- a/src/duckdb/src/main/client_config.cpp +++ b/src/duckdb/src/main/client_config.cpp @@ -8,8 +8,8 @@ bool ClientConfig::AnyVerification() const { return query_verification_enabled || verify_external || verify_serializer || verify_fetch_row; } -void ClientConfig::SetUserVariable(const string &name, Value value) { - user_variables[name] = std::move(value); +void ClientConfig::SetUserVariable(const String &name, Value value) { + user_variables[name.ToStdString()] = std::move(value); } bool ClientConfig::GetUserVariable(const string &name, Value &result) { diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index f52fbabdd..7ceb8ce59 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -51,6 +51,7 @@ #include "duckdb/logging/log_type.hpp" #include "duckdb/logging/log_manager.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/main/result_set_manager.hpp" namespace duckdb { @@ -333,7 +334,8 @@ unique_ptr ClientContext::FetchResultInternal(ClientContextLock &lo D_ASSERT(active_query->prepared); auto &executor = GetExecutor(); auto &prepared = *active_query->prepared; - bool create_stream_result = prepared.properties.allow_stream_result && pending.allow_stream_result; + bool create_stream_result = + prepared.properties.output_type == QueryResultOutputType::ALLOW_STREAMING && pending.allow_stream_result; unique_ptr result; D_ASSERT(executor.HasResultCollector()); // we have a result collector - fetch the result directly from the result collector @@ -357,19 +359,19 @@ static bool IsExplainAnalyze(SQLStatement *statement) { return explain.explain_type == ExplainType::EXPLAIN_ANALYZE; } -shared_ptr -ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, - unique_ptr statement, - optional_ptr> values) { +shared_ptr ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, + const string &query, + unique_ptr statement, + PendingQueryParameters parameters) { StatementType statement_type = statement->type; auto result = make_shared_ptr(statement_type); auto &profiler = QueryProfiler::Get(*this); 
profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true); - profiler.StartPhase(MetricsType::PLANNER); + profiler.StartPhase(MetricType::PLANNER); Planner logical_planner(*this); - if (values) { - auto ¶meter_values = *values; + if (parameters.parameters) { + auto ¶meter_values = *parameters.parameters; for (auto &value : parameter_values) { logical_planner.parameter_data.emplace(value.first, BoundParameterData(value.second)); } @@ -392,7 +394,7 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st logical_plan->Verify(*this); #endif if (config.enable_optimizer && logical_plan->RequireOptimizer()) { - profiler.StartPhase(MetricsType::ALL_OPTIMIZERS); + profiler.StartPhase(MetricType::ALL_OPTIMIZERS); Optimizer optimizer(*logical_planner.binder, *this); logical_plan = optimizer.Optimize(std::move(logical_plan)); D_ASSERT(logical_plan); @@ -404,7 +406,7 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st } // Convert the logical query plan into a physical query plan. - profiler.StartPhase(MetricsType::PHYSICAL_PLANNER); + profiler.StartPhase(MetricType::PHYSICAL_PLANNER); PhysicalPlanGenerator physical_planner(*this); result->physical_plan = physical_planner.Plan(std::move(logical_plan)); profiler.EndPhase(); @@ -412,10 +414,10 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st return result; } -shared_ptr -ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> values, - PreparedStatementMode mode) { +shared_ptr ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, + unique_ptr statement, + PendingQueryParameters parameters, + PreparedStatementMode mode) { // check if any client context state could request a rebind bool can_request_rebind = false; for (auto &state : registered_state->States()) { @@ -428,7 +430,7 @@ ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &qu // if any registered state can request a rebind we do the binding on a copy first shared_ptr result; try { - result = CreatePreparedStatementInternal(lock, query, statement->Copy(), values); + result = CreatePreparedStatementInternal(lock, query, statement->Copy(), parameters); } catch (std::exception &ex) { ErrorData error(ex); // check if any registered client context state wants to try a rebind @@ -457,7 +459,7 @@ ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &qu // an extension wants to do a rebind - do it once } - return CreatePreparedStatementInternal(lock, query, std::move(statement), values); + return CreatePreparedStatementInternal(lock, query, std::move(statement), parameters); } QueryProgress ClientContext::GetQueryProgress() { @@ -483,8 +485,7 @@ void ClientContext::RebindPreparedStatement(ClientContextLock &lock, const strin "an unbound statement so rebinding cannot be done"); } // catalog was modified: rebind the statement before execution - auto new_prepared = - CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters.parameters); + auto new_prepared = CreatePreparedStatement(lock, query, prepared->unbound_statement->Copy(), parameters); D_ASSERT(new_prepared->properties.bound_all_parameters); new_prepared->properties.parameter_count = prepared->properties.parameter_count; prepared = std::move(new_prepared); @@ -510,7 +511,7 @@ void ClientContext::CheckIfPreparedStatementIsExecutable(PreparedStatementData & "Cannot execute 
statement of type \"%s\" on database \"%s\" which is attached in read-only mode!", StatementTypeToString(statement.statement_type), modified_database)); } - meta_transaction.ModifyDatabase(*entry); + meta_transaction.ModifyDatabase(*entry, it.second.modifications); } } @@ -539,7 +540,8 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, query_progress.Restart(); } - auto stream_result = parameters.allow_stream_result && statement_data.properties.allow_stream_result; + const auto stream_result = parameters.query_parameters.output_type == QueryResultOutputType::ALLOW_STREAMING && + statement_data.properties.output_type == QueryResultOutputType::ALLOW_STREAMING; // Decide how to get the result collector. get_result_collector_t get_collector = PhysicalResultCollector::GetResultCollector; @@ -547,7 +549,9 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, if (!stream_result && client_config.get_result_collector) { get_collector = client_config.get_result_collector; } - statement_data.is_streaming = stream_result; + statement_data.output_type = + stream_result ? QueryResultOutputType::ALLOW_STREAMING : QueryResultOutputType::FORCE_MATERIALIZED; + statement_data.memory_type = parameters.query_parameters.memory_type; // Get the result collector and initialize the executor. auto &collector = get_collector(*this, statement_data); @@ -707,7 +711,8 @@ unique_ptr ClientContext::PrepareInternal(ClientContextLock & shared_ptr prepared_data; auto unbound_statement = statement->Copy(); RunFunctionInTransactionInternal( - lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement)); }, false); + lock, [&]() { prepared_data = CreatePreparedStatement(lock, statement_query, std::move(statement), {}); }, + false); prepared_data->unbound_statement = std::move(unbound_statement); return make_uniq(shared_from_this(), std::move(prepared_data), std::move(statement_query), std::move(named_param_map)); @@ -775,10 +780,10 @@ unique_ptr ClientContext::Execute(const string &query, shared_ptr

ClientContext::Execute(const string &query, shared_ptr &prepared, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters query_parameters) { PendingQueryParameters parameters; parameters.parameters = &values; - parameters.allow_stream_result = allow_stream_result; + parameters.query_parameters = query_parameters; return Execute(query, prepared, parameters); } @@ -790,7 +795,7 @@ unique_ptr ClientContext::PendingStatementInternal(ClientCon PreparedStatement::VerifyParameters(*parameters.parameters, statement->named_param_map); } - auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters.parameters, + auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters, PreparedStatementMode::PREPARE_AND_EXECUTE); idx_t parameter_count = !parameters.parameters ? 0 : parameters.parameters->size(); @@ -807,13 +812,9 @@ unique_ptr ClientContext::PendingStatementInternal(ClientCon return PendingPreparedStatementInternal(lock, std::move(prepared), parameters); } -unique_ptr -ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, unique_ptr statement, - bool allow_stream_result, - optional_ptr> params, bool verify) { - PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; - parameters.parameters = params; +unique_ptr ClientContext::RunStatementInternal(ClientContextLock &lock, const string &query, + unique_ptr statement, + const PendingQueryParameters ¶meters, bool verify) { auto pending = PendingQueryInternal(lock, std::move(statement), parameters, verify); if (pending->HasError()) { return ErrorResult(pending->GetErrorObject()); @@ -846,7 +847,7 @@ unique_ptr ClientContext::PendingStatementOrPreparedStatemen // in case this is a select query, we verify the original statement ErrorData error; try { - error = VerifyQuery(lock, query, std::move(statement), parameters.parameters); + error = VerifyQuery(lock, query, std::move(statement), parameters); } catch (std::exception &ex) { error = ErrorData(ex); } @@ -958,15 +959,15 @@ void ClientContext::LogQueryInternal(ClientContextLock &, const string &query) { client_data->log_query_writer->Sync(); } -unique_ptr ClientContext::Query(unique_ptr statement, bool allow_stream_result) { - auto pending_query = PendingQuery(std::move(statement), allow_stream_result); +unique_ptr ClientContext::Query(unique_ptr statement, QueryParameters parameters) { + auto pending_query = PendingQuery(std::move(statement), parameters); if (pending_query->HasError()) { return ErrorResult(pending_query->GetErrorObject()); } return pending_query->Execute(); } -unique_ptr ClientContext::Query(const string &query, bool allow_stream_result) { +unique_ptr ClientContext::Query(const string &query, QueryParameters query_parameters) { auto lock = LockContext(); vector> statements; @@ -991,7 +992,10 @@ unique_ptr ClientContext::Query(const string &query, bool allow_str auto &statement = statements[i]; bool is_last_statement = i + 1 == statements.size(); PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result && is_last_statement; + parameters.query_parameters = query_parameters; + if (!is_last_statement) { + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + } auto pending_query = PendingQueryInternal(*lock, std::move(statement), parameters); auto has_result = pending_query->properties.return_type == StatementReturnType::QUERY_RESULT; unique_ptr current_result; @@ -1032,20 +1036,27 @@ 
vector> ClientContext::ParseStatements(ClientContextLoc return ParseStatementsInternal(lock, query); } -unique_ptr ClientContext::PendingQuery(const string &query, bool allow_stream_result) { +unique_ptr ClientContext::PendingQuery(const string &query, QueryParameters parameters) { case_insensitive_map_t empty_param_list; - return PendingQuery(query, empty_param_list, allow_stream_result); + return PendingQuery(query, empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(unique_ptr statement, - bool allow_stream_result) { + QueryParameters parameters) { case_insensitive_map_t empty_param_list; - return PendingQuery(std::move(statement), empty_param_list, allow_stream_result); + return PendingQuery(std::move(statement), empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(const string &query, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters parameters) { + PendingQueryParameters params; + params.parameters = values; + params.query_parameters = parameters; + return PendingQuery(query, params); +} + +unique_ptr ClientContext::PendingQuery(const string &query, PendingQueryParameters parameters) { auto lock = LockContext(); try { InitialCleanup(*lock); @@ -1058,11 +1069,7 @@ unique_ptr ClientContext::PendingQuery(const string &query, throw InvalidInputException("Cannot prepare multiple statements at once!"); } - PendingQueryParameters params; - params.allow_stream_result = allow_stream_result; - params.parameters = values; - - return PendingQueryInternal(*lock, std::move(statements[0]), params, true); + return PendingQueryInternal(*lock, std::move(statements[0]), parameters, true); } catch (std::exception &ex) { ErrorData error(ex); ProcessError(error, query); @@ -1072,14 +1079,14 @@ unique_ptr ClientContext::PendingQuery(const string &query, unique_ptr ClientContext::PendingQuery(unique_ptr statement, case_insensitive_map_t &values, - bool allow_stream_result) { + QueryParameters parameters) { auto lock = LockContext(); auto query = statement->query; try { InitialCleanup(*lock); PendingQueryParameters params; - params.allow_stream_result = allow_stream_result; + params.query_parameters = parameters; params.parameters = values; return PendingQueryInternal(*lock, std::move(statement), params, true); @@ -1109,6 +1116,10 @@ void ClientContext::Interrupt() { interrupted = true; } +bool ClientContext::IsInterrupted() const { + return interrupted; +} + void ClientContext::CancelTransaction() { auto lock = LockContext(); InitialCleanup(*lock); @@ -1335,7 +1346,7 @@ unordered_set ClientContext::GetTableNames(const string &query, const bo unique_ptr ClientContext::PendingQueryInternal(ClientContextLock &lock, const shared_ptr &relation, - bool allow_stream_result) { + QueryParameters query_parameters) { InitialCleanup(lock); string query; @@ -1347,20 +1358,23 @@ unique_ptr ClientContext::PendingQueryInternal(ClientContext // verify read only statements by running a select statement auto select = make_uniq(); select->node = relation->GetQueryNode(); - RunStatementInternal(lock, query, std::move(select), false, nullptr); + PendingQueryParameters parameters; + parameters.query_parameters = query_parameters; + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + RunStatementInternal(lock, query, std::move(select), parameters); } } auto relation_stmt = make_uniq(relation); PendingQueryParameters parameters; - parameters.allow_stream_result = allow_stream_result; + parameters.query_parameters = 
query_parameters; return PendingQueryInternal(lock, std::move(relation_stmt), parameters); } unique_ptr ClientContext::PendingQuery(const shared_ptr &relation, - bool allow_stream_result) { + QueryParameters query_parameters) { auto lock = LockContext(); - return PendingQueryInternal(*lock, relation, allow_stream_result); + return PendingQueryInternal(*lock, relation, query_parameters); } unique_ptr ClientContext::Execute(const shared_ptr &relation) { @@ -1443,6 +1457,7 @@ ParserOptions ClientContext::GetParserOptions() const { options.integer_division = DBConfig::GetSetting(*this); options.max_expression_depth = client_config.max_expression_depth; options.extensions = &DBConfig::GetConfig(*this).parser_extensions; + options.parser_override_setting = DBConfig::GetConfig(*this).options.allow_parser_override_extension; return options; } diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp index 1348c0b09..a9abf18b0 100644 --- a/src/duckdb/src/main/client_data.cpp +++ b/src/duckdb/src/main/client_data.cpp @@ -35,28 +35,58 @@ class ClientFileSystem : public OpenerFileSystem { //! ClientBufferManager wraps the buffer manager to optionally forward the client context. class ClientBufferManager : public BufferManager { public: - explicit ClientBufferManager(BufferManager &buffer_manager_p) : buffer_manager(buffer_manager_p) { + explicit ClientBufferManager(ClientContext &context_p, BufferManager &buffer_manager_p) + : context(context_p), buffer_manager(buffer_manager_p) { } public: shared_ptr AllocateTemporaryMemory(MemoryTag tag, idx_t block_size, bool can_destroy = true) override { - return buffer_manager.AllocateTemporaryMemory(tag, block_size, can_destroy); + auto result = buffer_manager.AllocateTemporaryMemory(tag, block_size, can_destroy); + // Track allocation based on actual allocated size from the handle + if (result) { + TrackMemoryAllocation(result->GetMemoryUsage()); + } + return result; } shared_ptr AllocateMemory(MemoryTag tag, BlockManager *block_manager, bool can_destroy = true) override { - return buffer_manager.AllocateMemory(tag, block_manager, can_destroy); + auto result = buffer_manager.AllocateMemory(tag, block_manager, can_destroy); + // Track allocation based on actual allocated size from the handle + if (result) { + TrackMemoryAllocation(result->GetMemoryUsage()); + } + return result; } BufferHandle Allocate(MemoryTag tag, idx_t block_size, bool can_destroy = true) override { - return buffer_manager.Allocate(tag, block_size, can_destroy); + auto result = buffer_manager.Allocate(tag, block_size, can_destroy); + // Track allocation based on actual allocated size from the handle + if (result.GetBlockHandle()) { + TrackMemoryAllocation(result.GetBlockHandle()->GetMemoryUsage()); + } + return result; } BufferHandle Allocate(MemoryTag tag, BlockManager *block_manager, bool can_destroy = true) override { - return buffer_manager.Allocate(tag, block_manager, can_destroy); + auto result = buffer_manager.Allocate(tag, block_manager, can_destroy); + // Track allocation based on actual allocated size from the handle + if (result.GetBlockHandle()) { + TrackMemoryAllocation(result.GetBlockHandle()->GetMemoryUsage()); + } + return result; } void ReAllocate(shared_ptr &handle, idx_t block_size) override { - return buffer_manager.ReAllocate(handle, block_size); + // Track the difference in size (new size - old size) + idx_t old_size = handle->GetMemoryUsage(); + buffer_manager.ReAllocate(handle, block_size); + idx_t new_size = handle->GetMemoryUsage(); + if 
(new_size > old_size) { + TrackMemoryAllocation(new_size - old_size); + } } BufferHandle Pin(shared_ptr &handle) override { - return buffer_manager.Pin(handle); + return Pin(QueryContext(), handle); + } + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) override { + return buffer_manager.Pin(context, handle); } void Prefetch(vector> &handles) override { return buffer_manager.Prefetch(handles); @@ -88,13 +118,19 @@ class ClientBufferManager : public BufferManager { } shared_ptr RegisterTransientMemory(const idx_t size, BlockManager &block_manager) override { - return buffer_manager.RegisterTransientMemory(size, block_manager); + auto result = buffer_manager.RegisterTransientMemory(size, block_manager); + TrackMemoryAllocation(size); + return result; } shared_ptr RegisterSmallMemory(const idx_t size) override { - return buffer_manager.RegisterSmallMemory(size); + auto result = buffer_manager.RegisterSmallMemory(size); + TrackMemoryAllocation(size); + return result; } shared_ptr RegisterSmallMemory(MemoryTag tag, const idx_t size) override { - return buffer_manager.RegisterSmallMemory(tag, size); + auto result = buffer_manager.RegisterSmallMemory(tag, size); + TrackMemoryAllocation(size); + return result; } Allocator &GetBufferAllocator() override { @@ -116,6 +152,9 @@ class ClientBufferManager : public BufferManager { return buffer_manager.SetSwapLimit(limit); } + BlockManager &GetTemporaryBlockManager() override { + return buffer_manager.GetTemporaryBlockManager(); + } vector GetTemporaryFiles() override { return buffer_manager.GetTemporaryFiles(); } @@ -164,6 +203,16 @@ class ClientBufferManager : public BufferManager { } private: + void TrackMemoryAllocation(idx_t size) const { + if (size > 0) { + auto &profiler = QueryProfiler::Get(context); + // Track allocations even if profiler isn't running yet - they'll be included when the query starts + // AddToCounter already checks IsEnabled(), so we don't need to check here + profiler.AddToCounter(MetricType::TOTAL_MEMORY_ALLOCATED, size); + } + } + + ClientContext &context; BufferManager &buffer_manager; }; @@ -176,7 +225,7 @@ ClientData::ClientData(ClientContext &context) : catalog_search_path(make_uniq(); file_opener = make_uniq(context); client_file_system = make_uniq(context); - client_buffer_manager = make_uniq(db.GetBufferManager()); + client_buffer_manager = make_uniq(context, db.GetBufferManager()); temporary_objects->Initialize(); } diff --git a/src/duckdb/src/main/client_verify.cpp b/src/duckdb/src/main/client_verify.cpp index 05b190b07..2287c8b9e 100644 --- a/src/duckdb/src/main/client_verify.cpp +++ b/src/duckdb/src/main/client_verify.cpp @@ -22,7 +22,7 @@ static void ThrowIfExceptionIsInternal(StatementVerifier &verifier) { } ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &query, unique_ptr statement, - optional_ptr> parameters) { + PendingQueryParameters query_parameters) { D_ASSERT(statement->type == StatementType::SELECT_STATEMENT); // Aggressive query verification @@ -32,6 +32,10 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool run_slow_verifiers = false; #endif + auto parameters = query_parameters.parameters; + query_parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + query_parameters.query_parameters.memory_type = QueryResultMemoryType::IN_MEMORY; + // The purpose of this function is to test correctness of otherwise hard to test features: // Copy() of statements and expressions // Serialize()/Deserialize() 
of expressions @@ -98,7 +102,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool any_failed = original->Run(*this, query, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (!any_failed) { statement_verifiers.emplace_back( @@ -109,7 +113,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer bool failed = verifier->Run(*this, query, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); any_failed = any_failed || failed; } @@ -120,7 +124,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer *this, query, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (!failed) { // PreparedStatementVerifier fails if it runs into a ParameterNotAllowedException, which is OK @@ -155,7 +159,7 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer *this, explain_q, [&](const string &q, unique_ptr s, optional_ptr> params) { - return RunStatementInternal(lock, q, std::move(s), false, params, false); + return RunStatementInternal(lock, q, std::move(s), query_parameters, false); }); if (explain_failed) { // LCOV_EXCL_START @@ -173,7 +177,8 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer // test with a random width config.max_width = random.NextRandomInteger() % 500; BoxRenderer renderer(config); - renderer.ToString(*this, original->materialized_result->names, original->materialized_result->Collection()); + auto pinned_result_set = original->materialized_result->Pin(); + renderer.ToString(*this, original->materialized_result->names, pinned_result_set->collection); #endif } diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 78b174902..a73c83a8f 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -8,6 +8,7 @@ #include "duckdb/main/settings.hpp" #include "duckdb/storage/storage_extension.hpp" #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #ifndef DUCKDB_NO_THREADS #include "duckdb/common/thread.hpp" @@ -63,6 +64,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(AllocatorFlushThresholdSetting), DUCKDB_GLOBAL(AllowCommunityExtensionsSetting), DUCKDB_SETTING(AllowExtensionsMetadataMismatchSetting), + DUCKDB_GLOBAL(AllowParserOverrideExtensionSetting), DUCKDB_GLOBAL(AllowPersistentSecretsSetting), DUCKDB_GLOBAL(AllowUnredactedSecretsSetting), DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), @@ -76,6 +78,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(AutoinstallExtensionRepositorySetting), DUCKDB_GLOBAL(AutoinstallKnownExtensionsSetting), DUCKDB_GLOBAL(AutoloadKnownExtensionsSetting), + DUCKDB_GLOBAL(BlockAllocatorMemorySetting), DUCKDB_SETTING(CatalogErrorMaxSchemasSetting), DUCKDB_GLOBAL(CheckpointThresholdSetting), DUCKDB_GLOBAL(CustomExtensionRepositorySetting), @@ -83,9 +86,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(CustomUserAgentSetting), 
DUCKDB_SETTING(DebugAsofIejoinSetting), DUCKDB_SETTING_CALLBACK(DebugCheckpointAbortSetting), + DUCKDB_SETTING(DebugCheckpointSleepMsSetting), DUCKDB_LOCAL(DebugForceExternalSetting), DUCKDB_SETTING(DebugForceNoCrossProductSetting), + DUCKDB_SETTING_CALLBACK(DebugPhysicalTableScanExecutionStrategySetting), DUCKDB_SETTING(DebugSkipCheckpointOnCommitSetting), + DUCKDB_SETTING(DebugVerifyBlocksSetting), DUCKDB_SETTING_CALLBACK(DebugVerifyVectorSetting), DUCKDB_SETTING_CALLBACK(DebugWindowModeSetting), DUCKDB_GLOBAL(DefaultBlockSizeSetting), @@ -117,11 +123,13 @@ static const ConfigurationOption internal_options[] = { DUCKDB_LOCAL(ErrorsAsJSONSetting), DUCKDB_SETTING(ExperimentalMetadataReuseSetting), DUCKDB_LOCAL(ExplainOutputSetting), + DUCKDB_GLOBAL(ExtensionDirectoriesSetting), DUCKDB_GLOBAL(ExtensionDirectorySetting), DUCKDB_GLOBAL(ExternalThreadsSetting), DUCKDB_LOCAL(FileSearchPathSetting), DUCKDB_GLOBAL(ForceBitpackingModeSetting), DUCKDB_GLOBAL(ForceCompressionSetting), + DUCKDB_GLOBAL(ForceVariantShredding), DUCKDB_LOCAL(HomeDirectorySetting), DUCKDB_LOCAL(HTTPLoggingOutputSetting), DUCKDB_GLOBAL(HTTPProxySetting), @@ -175,16 +183,17 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(TempFileEncryptionSetting), DUCKDB_GLOBAL(ThreadsSetting), DUCKDB_GLOBAL(UsernameSetting), + DUCKDB_GLOBAL(VariantMinimumShreddingSize), DUCKDB_SETTING(WriteBufferRowGroupCountSetting), DUCKDB_GLOBAL(ZstdMinStringLengthSetting), FINAL_SETTING}; -static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 83), - DUCKDB_SETTING_ALIAS("null_order", 33), - DUCKDB_SETTING_ALIAS("profiling_output", 102), - DUCKDB_SETTING_ALIAS("user", 117), - DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 20), - DUCKDB_SETTING_ALIAS("worker_threads", 116), +static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 90), + DUCKDB_SETTING_ALIAS("null_order", 38), + DUCKDB_SETTING_ALIAS("profiling_output", 109), + DUCKDB_SETTING_ALIAS("user", 124), + DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 22), + DUCKDB_SETTING_ALIAS("worker_threads", 123), FINAL_ALIAS}; vector DBConfig::GetOptions() { @@ -326,9 +335,9 @@ void DBConfig::ResetOption(optional_ptr db, const Configuratio option.reset_global(db.get(), *this); } -void DBConfig::SetOption(const string &name, Value value) { +void DBConfig::SetOption(const String &name, Value value) { lock_guard l(config_lock); - options.set_variables[name] = std::move(value); + options.set_variables[name.ToStdString()] = std::move(value); } void DBConfig::ResetOption(const String &name) { @@ -440,8 +449,14 @@ LogicalType DBConfig::ParseLogicalType(const string &type) { return type_id; } +bool DBConfig::HasExtensionOption(const string &name) { + lock_guard l(config_lock); + return extension_parameters.find(name) != extension_parameters.end(); +} + void DBConfig::AddExtensionOption(const string &name, string description, LogicalType parameter, const Value &default_value, set_option_callback_t function, SetScope default_scope) { + lock_guard l(config_lock); extension_parameters.insert(make_pair( name, ExtensionOption(std::move(description), std::move(parameter), function, default_value, default_scope))); // copy over unrecognized options, if they match the new extension option @@ -517,8 +532,7 @@ void DBConfig::CheckLock(const String &name) { return; } // not allowed! 
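The new HasExtensionOption and the lock added to AddExtensionOption above put every access to the extension-parameter map behind the same config lock. A minimal standalone sketch of that pattern, using a hypothetical OptionRegistry rather than DuckDB's actual DBConfig, might look like this:

#include <map>
#include <mutex>
#include <string>

class OptionRegistry {
public:
	// Readers and writers take the same mutex, so concurrent Has()/Add() calls are safe.
	bool Has(const std::string &name) {
		std::lock_guard<std::mutex> guard(lock);
		return options.find(name) != options.end();
	}
	void Add(const std::string &name, std::string description) {
		std::lock_guard<std::mutex> guard(lock);
		options.emplace(name, std::move(description));
	}

private:
	std::mutex lock;
	std::map<std::string, std::string> options;
};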
- throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", - name.ToStdString()); + throw InvalidInputException("Cannot change configuration option \"%s\" - the configuration has been locked", name); } idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { diff --git a/src/duckdb/src/main/connection.cpp b/src/duckdb/src/main/connection.cpp index e561a3cb9..5546594a8 100644 --- a/src/duckdb/src/main/connection.cpp +++ b/src/duckdb/src/main/connection.cpp @@ -19,30 +19,22 @@ namespace duckdb { Connection::Connection(DatabaseInstance &database) - : context(make_shared_ptr(database.shared_from_this())), warning_cb(nullptr) { + : context(make_shared_ptr(database.shared_from_this())) { auto &connection_manager = ConnectionManager::Get(database); connection_manager.AddConnection(*context); connection_manager.AssignConnectionId(*this); - -#ifdef DEBUG - EnableProfiling(); - context->config.emit_profiler_output = false; -#endif } Connection::Connection(DuckDB &database) : Connection(*database.instance) { - // Initialization of warning_cb happens in the other constructor } -Connection::Connection(Connection &&other) noexcept : warning_cb(nullptr) { +Connection::Connection(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); } Connection &Connection::operator=(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); return *this; } @@ -98,40 +90,51 @@ void Connection::ForceParallelism() { ClientConfig::GetConfig(*context).verify_parallelism = true; } -unique_ptr Connection::SendQuery(const string &query) { - return context->Query(query, true); +unique_ptr Connection::SendQuery(const string &query, QueryParameters query_parameters) { + return context->Query(query, query_parameters); +} + +unique_ptr Connection::SendQuery(unique_ptr statement, QueryParameters query_parameters) { + return context->Query(std::move(statement), query_parameters); } unique_ptr Connection::Query(const string &query) { - auto result = context->Query(query, false); + QueryParameters query_parameters; + query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + auto result = context->Query(query, query_parameters); D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); return unique_ptr_cast(std::move(result)); } -unique_ptr Connection::Query(unique_ptr statement) { - auto result = context->Query(std::move(statement), false); +unique_ptr Connection::Query(unique_ptr statement, + QueryResultMemoryType memory_type) { + QueryParameters query_parameters; + query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + query_parameters.memory_type = memory_type; + auto result = context->Query(std::move(statement), query_parameters); D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT); return unique_ptr_cast(std::move(result)); } -unique_ptr Connection::PendingQuery(const string &query, bool allow_stream_result) { - return context->PendingQuery(query, allow_stream_result); +unique_ptr Connection::PendingQuery(const string &query, QueryParameters query_parameters) { + return context->PendingQuery(query, query_parameters); } -unique_ptr Connection::PendingQuery(unique_ptr statement, bool allow_stream_result) { - return context->PendingQuery(std::move(statement), allow_stream_result); +unique_ptr Connection::PendingQuery(unique_ptr 
statement, + QueryParameters query_parameters) { + return context->PendingQuery(std::move(statement), query_parameters); } unique_ptr Connection::PendingQuery(const string &query, case_insensitive_map_t &named_values, - bool allow_stream_result) { - return context->PendingQuery(query, named_values, allow_stream_result); + QueryParameters query_parameters) { + return context->PendingQuery(query, named_values, query_parameters); } unique_ptr Connection::PendingQuery(unique_ptr statement, case_insensitive_map_t &named_values, - bool allow_stream_result) { - return context->PendingQuery(std::move(statement), named_values, allow_stream_result); + QueryParameters query_parameters) { + return context->PendingQuery(std::move(statement), named_values, query_parameters); } static case_insensitive_map_t ConvertParamListToMap(vector ¶m_list) { @@ -144,15 +147,19 @@ static case_insensitive_map_t ConvertParamListToMap(vector Connection::PendingQuery(const string &query, vector &values, - bool allow_stream_result) { + QueryParameters query_parameters) { auto named_params = ConvertParamListToMap(values); - return context->PendingQuery(query, named_params, allow_stream_result); + return context->PendingQuery(query, named_params, query_parameters); } unique_ptr Connection::PendingQuery(unique_ptr statement, vector &values, - bool allow_stream_result) { + QueryParameters query_parameters) { auto named_params = ConvertParamListToMap(values); - return context->PendingQuery(std::move(statement), named_params, allow_stream_result); + return context->PendingQuery(std::move(statement), named_params, query_parameters); +} + +unique_ptr Connection::PendingQuery(const string &query, PendingQueryParameters parameters) { + return context->PendingQuery(query, parameters); } unique_ptr Connection::Prepare(const string &query) { @@ -165,7 +172,11 @@ unique_ptr Connection::Prepare(unique_ptr state unique_ptr Connection::QueryParamsRecursive(const string &query, vector &values) { auto named_params = ConvertParamListToMap(values); - auto pending = PendingQuery(query, named_params, false); + PendingQueryParameters parameters; + parameters.parameters = &named_params; + parameters.query_parameters.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + parameters.query_parameters.memory_type = QueryResultMemoryType::BUFFER_MANAGED; + auto pending = PendingQuery(query, parameters); if (pending->HasError()) { return make_uniq(pending->GetErrorObject()); } diff --git a/src/duckdb/src/main/database.cpp b/src/duckdb/src/main/database.cpp index 3d644d408..6e68d4e8d 100644 --- a/src/duckdb/src/main/database.cpp +++ b/src/duckdb/src/main/database.cpp @@ -23,6 +23,7 @@ #include "duckdb/storage/object_cache.hpp" #include "duckdb/storage/standard_buffer_manager.hpp" #include "duckdb/storage/storage_extension.hpp" +#include "duckdb/storage/block_allocator.hpp" #include "duckdb/storage/storage_manager.hpp" #include "duckdb/transaction/transaction_manager.hpp" #include "duckdb/main/capi/extension_api.hpp" @@ -32,6 +33,7 @@ #include "duckdb/common/http_util.hpp" #include "mbedtls_wrapper.hpp" #include "duckdb/main/database_file_path_manager.hpp" +#include "duckdb/main/result_set_manager.hpp" #ifndef DUCKDB_NO_THREADS #include "duckdb/common/thread.hpp" @@ -87,13 +89,12 @@ DatabaseInstance::~DatabaseInstance() { log_manager.reset(); external_file_cache.reset(); + result_set_manager.reset(); buffer_manager.reset(); // flush allocations and disable the background thread - if (Allocator::SupportsFlush()) { - Allocator::FlushAll(); - } + 
config.block_allocator->FlushAll(); Allocator::SetBackgroundThreads(false); // after all destruction is complete clear the cache entry config.db_cache_entry.reset(); @@ -283,10 +284,11 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf buffer_manager = make_uniq(*this, config.options.temporary_directory); } - log_manager = make_shared_ptr(*this, LogConfig()); + log_manager = make_uniq(*this, LogConfig()); log_manager->Initialize(); external_file_cache = make_uniq(*this, config.options.enable_external_file_cache); + result_set_manager = make_uniq(*this); scheduler = make_uniq(*this); object_cache = make_uniq(); @@ -382,6 +384,10 @@ ExternalFileCache &DatabaseInstance::GetExternalFileCache() { return *external_file_cache; } +ResultSetManager &DatabaseInstance::GetResultSetManager() { + return *result_set_manager; +} + ConnectionManager &DatabaseInstance::GetConnectionManager() { return *connection_manager; } @@ -455,6 +461,9 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path if (!config.allocator) { config.allocator = make_uniq(); } + config.block_allocator = make_uniq(*config.allocator, config.options.default_block_alloc_size, + DBConfig::GetSystemAvailableMemory(*config.file_system) * 8 / 10, + config.options.block_allocator_size); config.replacement_scans = std::move(new_config.replacement_scans); config.parser_extensions = std::move(new_config.parser_extensions); config.error_manager = std::move(new_config.error_manager); @@ -467,7 +476,7 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path if (new_config.buffer_pool) { config.buffer_pool = std::move(new_config.buffer_pool); } else { - config.buffer_pool = make_shared_ptr(config.options.maximum_memory, + config.buffer_pool = make_shared_ptr(*config.block_allocator, config.options.maximum_memory, config.options.buffer_manager_track_eviction_timestamps, config.options.allocator_bulk_deallocation_flush_threshold); } @@ -505,12 +514,18 @@ SettingLookupResult DatabaseInstance::TryGetCurrentSetting(const string &key, Va return db_config.TryGetCurrentSetting(key, result); } -shared_ptr DatabaseInstance::GetEncryptionUtil() const { +shared_ptr DatabaseInstance::GetEncryptionUtil() { + if (!config.encryption_util || !config.encryption_util->SupportsEncryption()) { + ExtensionHelper::TryAutoLoadExtension(*this, "httpfs"); + } + if (config.encryption_util) { return config.encryption_util; } - return make_shared_ptr(); + auto result = make_shared_ptr(); + + return std::move(result); } ValidChecker &DatabaseInstance::GetValidChecker() { diff --git a/src/duckdb/src/main/database_file_path_manager.cpp b/src/duckdb/src/main/database_file_path_manager.cpp index 05adeadfe..2e107210a 100644 --- a/src/duckdb/src/main/database_file_path_manager.cpp +++ b/src/duckdb/src/main/database_file_path_manager.cpp @@ -5,30 +5,57 @@ namespace duckdb { +DatabasePathInfo::DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode) + : name(std::move(name_p)), access_mode(access_mode) { + attached_databases.insert(manager); +} + idx_t DatabaseFilePathManager::ApproxDatabaseCount() const { lock_guard path_lock(db_paths_lock); return db_paths.size(); } -InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(const string &path, const string &name, - OnCreateConflict on_conflict, +InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(DatabaseManager &manager, const string &path, + const string &name, OnCreateConflict on_conflict, 
AttachOptions &options) { if (path.empty() || path == IN_MEMORY_PATH) { return InsertDatabasePathResult::SUCCESS; } lock_guard path_lock(db_paths_lock); - auto entry = db_paths.emplace(path, DatabasePathInfo(name)); + auto entry = db_paths.emplace(path, DatabasePathInfo(manager, name, options.access_mode)); if (!entry.second) { auto &existing = entry.first->second; + bool already_exists = false; + bool attached_in_this_system = false; if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT && existing.name == name) { - return InsertDatabasePathResult::ALREADY_EXISTS; + already_exists = true; + attached_in_this_system = existing.attached_databases.find(manager) != existing.attached_databases.end(); + } + if (options.access_mode == AccessMode::READ_ONLY && existing.access_mode == AccessMode::READ_ONLY) { + if (attached_in_this_system) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + // all attaches are in read-only mode - there is no conflict, just increase the reference count + existing.attached_databases.insert(manager); + existing.reference_count++; + } else { + if (already_exists) { + if (attached_in_this_system) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " + "the process of being detached", + name, path); + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " + "attached by database \"%s\"", + name, path, existing.name); } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " - "attached by database \"%s\"", - name, path, existing.name); } - options.stored_database_path = make_uniq(*this, path, name); + options.stored_database_path = make_uniq(manager, *this, path, name); return InsertDatabasePathResult::SUCCESS; } @@ -37,7 +64,25 @@ void DatabaseFilePathManager::EraseDatabasePath(const string &path) { return; } lock_guard path_lock(db_paths_lock); - db_paths.erase(path); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + if (entry->second.reference_count <= 1) { + db_paths.erase(entry); + } else { + entry->second.reference_count--; + } + } +} + +void DatabaseFilePathManager::DetachDatabase(DatabaseManager &manager, const string &path) { + if (path.empty() || path == IN_MEMORY_PATH) { + return; + } + lock_guard path_lock(db_paths_lock); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + entry->second.attached_databases.erase(manager); + } } } // namespace duckdb diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp index ae0a6447d..e4f3f89c5 100644 --- a/src/duckdb/src/main/database_manager.cpp +++ b/src/duckdb/src/main/database_manager.cpp @@ -84,7 +84,20 @@ shared_ptr DatabaseManager::GetDatabaseInternal(const lock_gua shared_ptr DatabaseManager::AttachDatabase(ClientContext &context, AttachInfo &info, AttachOptions &options) { + string extension = ""; + if (FileSystem::IsRemoteFile(info.path, extension)) { + if (options.access_mode == AccessMode::AUTOMATIC) { + // Attaching of remote files gets bumped to READ_ONLY + // This is due to the fact that on most (all?) remote files writes to DB are not available + // and having this raised later is not super helpful + options.access_mode = AccessMode::READ_ONLY; + } + } + if (options.db_type.empty() || StringUtil::CIEquals(options.db_type, "duckdb")) { + // Start timing the ATTACH-delay step. 
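The InsertDatabasePath/EraseDatabasePath changes above let the same database file be attached more than once as long as every attach is read-only, tracked with a per-path reference count. A simplified, hypothetical registry illustrating that rule (the real logic additionally records which DatabaseManager performed each attach):

#include <map>
#include <mutex>
#include <stdexcept>
#include <string>

enum class AccessMode { READ_ONLY, READ_WRITE };

struct PathInfo {
	AccessMode mode;
	int reference_count;
};

class PathRegistry {
public:
	void Insert(const std::string &path, AccessMode mode) {
		std::lock_guard<std::mutex> guard(lock);
		auto entry = paths.find(path);
		if (entry == paths.end()) {
			paths.emplace(path, PathInfo {mode, 1});
			return;
		}
		if (mode == AccessMode::READ_ONLY && entry->second.mode == AccessMode::READ_ONLY) {
			entry->second.reference_count++; // shared read-only attach: no conflict
			return;
		}
		throw std::runtime_error("path already attached with a conflicting access mode");
	}
	void Erase(const std::string &path) {
		std::lock_guard<std::mutex> guard(lock);
		auto entry = paths.find(path);
		if (entry != paths.end() && --entry->second.reference_count == 0) {
			paths.erase(entry); // last reference gone: forget the path
		}
	}

private:
	std::mutex lock;
	std::map<std::string, PathInfo> paths;
};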
+ auto profiler = context.client_data->profiler->StartTimer(MetricType::WAITING_TO_ATTACH_LATENCY); + while (InsertDatabasePath(info, options) == InsertDatabasePathResult::ALREADY_EXISTS) { // database with this name and path already exists // first check if it exists within this transaction @@ -94,12 +107,13 @@ shared_ptr DatabaseManager::AttachDatabase(ClientContext &cont // it does! return it return existing_db; } + // ... but it might not be done attaching yet! // verify the database has actually finished attaching prior to returning lock_guard guard(databases_lock); auto entry = databases.find(info.name); if (entry != databases.end()) { - // database ACTUALLY exists - return it + // The database ACTUALLY exists, so we return it. return entry->second; } if (context.interrupted) { @@ -117,18 +131,11 @@ shared_ptr DatabaseManager::AttachDatabase(ClientContext &cont if (AttachedDatabase::NameIsReserved(info.name)) { throw BinderException("Attached database name \"%s\" cannot be used because it is a reserved name", info.name); } - string extension = ""; - if (FileSystem::IsRemoteFile(info.path, extension)) { + if (!extension.empty()) { if (!ExtensionHelper::TryAutoLoadExtension(context, extension)) { throw MissingExtensionException("Attaching path '%s' requires extension '%s' to be loaded", info.path, extension); } - if (options.access_mode == AccessMode::AUTOMATIC) { - // Attaching of remote files gets bumped to READ_ONLY - // This is due to the fact that on most (all?) remote files writes to DB are not available - // and having this raised later is not super helpful - options.access_mode = AccessMode::READ_ONLY; - } } // now create the attached database @@ -270,7 +277,7 @@ idx_t DatabaseManager::ApproxDatabaseCount() { } InsertDatabasePathResult DatabaseManager::InsertDatabasePath(const AttachInfo &info, AttachOptions &options) { - return path_manager->InsertDatabasePath(info.path, info.name, info.on_conflict, options); + return path_manager->InsertDatabasePath(*this, info.path, info.name, info.on_conflict, options); } vector DatabaseManager::GetAttachedDatabasePaths() { @@ -293,7 +300,6 @@ vector DatabaseManager::GetAttachedDatabasePaths() { void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, const DBConfig &config, AttachOptions &options) { - // Test if the database is a DuckDB database file. if (StringUtil::CIEquals(options.db_type, "duckdb")) { options.db_type = ""; @@ -303,14 +309,15 @@ void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, // Try to extract the database type from the path. if (options.db_type.empty()) { auto &fs = FileSystem::GetFileSystem(context); - DBPathAndType::CheckMagicBytes(QueryContext(context), fs, info.path, options.db_type); + DBPathAndType::CheckMagicBytes(context, fs, info.path, options.db_type); } if (options.db_type.empty()) { return; } - if (config.storage_extensions.find(options.db_type) != config.storage_extensions.end()) { + auto extension_name = ExtensionHelper::ApplyExtensionAlias(options.db_type); + if (config.storage_extensions.find(extension_name) != config.storage_extensions.end()) { // If the database type is already registered, we don't need to load it again. 
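GetDatabaseType now runs the detected type through the extension alias table before looking it up in the registered storage extensions (the extension_alias.cpp hunk below adds "uc_catalog" -> "unity_catalog" to that table). A standalone sketch of the lookup, with an illustrative helper rather than the real ExtensionHelper::ApplyExtensionAlias:

#include <map>
#include <string>

// Illustrative alias resolution: unknown names pass through unchanged.
static std::string ApplyAlias(const std::string &name) {
	static const std::map<std::string, std::string> aliases = {
	    {"http", "httpfs"},
	    {"uc_catalog", "unity_catalog"},
	};
	auto it = aliases.find(name);
	return it == aliases.end() ? name : it->second;
}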
return; } diff --git a/src/duckdb/src/main/db_instance_cache.cpp b/src/duckdb/src/main/db_instance_cache.cpp index 57f4ee457..1960c5ee3 100644 --- a/src/duckdb/src/main/db_instance_cache.cpp +++ b/src/duckdb/src/main/db_instance_cache.cpp @@ -137,9 +137,23 @@ shared_ptr DBInstanceCache::CreateInstance(const string &database, DBCon shared_ptr DBInstanceCache::GetOrCreateInstance(const string &database, DBConfig &config_dict, bool cache_instance, const std::function &on_create) { + auto cache_behavior = cache_instance ? CacheBehavior::ALWAYS_CACHE : CacheBehavior::NEVER_CACHE; + return GetOrCreateInstance(database, config_dict, cache_behavior, on_create); +} + +shared_ptr DBInstanceCache::GetOrCreateInstance(const string &database, DBConfig &config_dict, + CacheBehavior cache_behavior, + const std::function &on_create) { unique_lock lock(cache_lock, std::defer_lock); + bool cache_instance = cache_behavior == CacheBehavior::ALWAYS_CACHE; + if (cache_behavior == CacheBehavior::AUTOMATIC) { + // cache all unnamed in-memory connections + cache_instance = true; + if (database == IN_MEMORY_PATH || database.empty()) { + cache_instance = false; + } + } if (cache_instance) { - // While we do not own the lock, we cannot definitively say that the database instance does not exist. while (!lock.owns_lock()) { // The problem is, that we have to unlock the mutex in GetInstanceInternal, so we can non-blockingly wait diff --git a/src/duckdb/src/main/extension.cpp b/src/duckdb/src/main/extension.cpp index cd786d863..c982a4bc3 100644 --- a/src/duckdb/src/main/extension.cpp +++ b/src/duckdb/src/main/extension.cpp @@ -7,6 +7,8 @@ namespace duckdb { +constexpr const idx_t ParsedExtensionMetaData::FOOTER_SIZE; + Extension::~Extension() { } diff --git a/src/duckdb/src/main/extension/extension_alias.cpp b/src/duckdb/src/main/extension/extension_alias.cpp index 81d3c1e1b..4a1ae7146 100644 --- a/src/duckdb/src/main/extension/extension_alias.cpp +++ b/src/duckdb/src/main/extension/extension_alias.cpp @@ -10,6 +10,7 @@ static const ExtensionAlias internal_aliases[] = {{"http", "httpfs"}, // httpfs {"postgres", "postgres_scanner"}, // postgres {"sqlite", "sqlite_scanner"}, // sqlite {"sqlite3", "sqlite_scanner"}, + {"uc_catalog", "unity_catalog"}, // old name for compatibility {nullptr, nullptr}}; idx_t ExtensionHelper::ExtensionAliasCount() { diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp index 74add5379..9058f19e1 100644 --- a/src/duckdb/src/main/extension/extension_helper.cpp +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -175,7 +175,6 @@ bool ExtensionHelper::CanAutoloadExtension(const string &ext_name) { string ExtensionHelper::AddExtensionInstallHintToErrorMsg(ClientContext &context, const string &base_error, const string &extension_name) { - return AddExtensionInstallHintToErrorMsg(DatabaseInstance::GetDatabase(context), base_error, extension_name); } string ExtensionHelper::AddExtensionInstallHintToErrorMsg(DatabaseInstance &db, const string &base_error, @@ -405,138 +404,6 @@ void ExtensionHelper::AutoLoadExtension(DatabaseInstance &db, const string &exte } } -//===--------------------------------------------------------------------===// -// Load Statically Compiled Extension -//===--------------------------------------------------------------------===// -void ExtensionHelper::LoadAllExtensions(DuckDB &db) { - // The in-tree extensions that we check. 
Non-cmake builds are currently limited to these for static linking - // TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds, after that - // these can be removed - vector extensions {"parquet", "icu", "tpch", "tpcds", "httpfs", "json", - "excel", "inet", "jemalloc", "autocomplete", "core_functions"}; - for (auto &ext : extensions) { - LoadExtensionInternal(db, ext, true); - } - -#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS - for (const auto &ext : LinkedExtensions()) { - LoadExtensionInternal(db, ext, true); - } -#endif -} - -ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string &extension) { - return LoadExtensionInternal(db, extension, false); -} - -ExtensionLoadResult ExtensionHelper::LoadExtensionInternal(DuckDB &db, const std::string &extension, - bool initial_load) { -#ifdef DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE - // Note: weird comma's are on purpose to do easy string contains on a list of extension names - if (!initial_load && StringUtil::Contains(DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE, "," + extension + ",")) { - Connection con(db); - auto result = con.Query((string) "LOAD '" + DUCKDB_EXTENSIONS_BUILD_PATH + "/" + extension + "/" + extension + - ".duckdb_extension'"); - if (result->HasError()) { - result->Print(); - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - return ExtensionLoadResult::LOADED_EXTENSION; - } -#endif - - // This is the main extension loading mechanism that loads the extension that are statically linked. -#if defined(GENERATED_EXTENSION_HEADERS) && GENERATED_EXTENSION_HEADERS - if (TryLoadLinkedExtension(db, extension)) { - return ExtensionLoadResult::LOADED_EXTENSION; - } else { - return ExtensionLoadResult::NOT_LOADED; - } -#endif - - // This is the fallback to the "old" extension loading mechanism for non-cmake builds - // TODO: rewrite package_build.py to allow also loading out-of-tree extensions in non-cmake builds - if (extension == "parquet") { -#if DUCKDB_EXTENSION_PARQUET_LINKED - db.LoadStaticExtension(); -#else - // parquet extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "icu") { -#if DUCKDB_EXTENSION_ICU_LINKED - db.LoadStaticExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "tpch") { -#if DUCKDB_EXTENSION_TPCH_LINKED - db.LoadStaticExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "tpcds") { -#if DUCKDB_EXTENSION_TPCDS_LINKED - db.LoadStaticExtension(); -#else - // icu extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "httpfs") { -#if DUCKDB_EXTENSION_HTTPFS_LINKED - db.LoadStaticExtension(); -#else - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "json") { -#if DUCKDB_EXTENSION_JSON_LINKED - db.LoadStaticExtension(); -#else - // json extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "excel") { -#if DUCKDB_EXTENSION_EXCEL_LINKED - db.LoadStaticExtension(); -#else - // excel extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "jemalloc") { -#if DUCKDB_EXTENSION_JEMALLOC_LINKED - db.LoadStaticExtension(); -#else - // jemalloc 
extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "autocomplete") { -#if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED - db.LoadStaticExtension(); -#else - // autocomplete extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "inet") { -#if DUCKDB_EXTENSION_INET_LINKED - db.LoadStaticExtension(); -#else - // inet extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } else if (extension == "core_functions") { -#if DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED - db.LoadStaticExtension(); -#else - // core_functions extension required but not build: skip this test - return ExtensionLoadResult::NOT_LOADED; -#endif - } - - return ExtensionLoadResult::LOADED_EXTENSION; -} - static const char *const public_keys[] = { R"( -----BEGIN PUBLIC KEY----- diff --git a/src/duckdb/src/main/extension/extension_install.cpp b/src/duckdb/src/main/extension/extension_install.cpp index 5be3dded1..36092f307 100644 --- a/src/duckdb/src/main/extension/extension_install.cpp +++ b/src/duckdb/src/main/extension/extension_install.cpp @@ -58,56 +58,100 @@ string ExtensionHelper::ExtensionInstallDocumentationLink(const string &extensio return link; } -duckdb::string ExtensionHelper::DefaultExtensionFolder(FileSystem &fs) { - string home_directory = fs.GetHomeDirectory(); - // exception if the home directory does not exist, don't create whatever we think is home - if (!fs.DirectoryExists(home_directory)) { - throw IOException("Can't find the home directory at '%s'\nSpecify a home directory using the SET " - "home_directory='/path/to/dir' option.", - home_directory); - } - string res = home_directory; - res = fs.JoinPath(res, ".duckdb"); - res = fs.JoinPath(res, "extensions"); - return res; +vector ExtensionHelper::DefaultExtensionFolders(FileSystem &fs) { + vector default_folders; +// These fallbacks are necessary if the user doesn't use the CMake build. 
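The DefaultExtensionFolders rewrite just below splits the compile-time DUCKDB_EXTENSION_DIRECTORIES value on ';' and drops empty entries. A small standard-library-only sketch of that split (the real code goes through StringUtil::Split):

#include <string>
#include <vector>

static std::vector<std::string> SplitDirectories(const std::string &dirs) {
	std::vector<std::string> result;
	std::string current;
	for (char c : dirs) {
		if (c == ';') {
			if (!current.empty()) {
				result.push_back(current); // finished one directory entry
			}
			current.clear();
		} else {
			current += c;
		}
	}
	if (!current.empty()) {
		result.push_back(current); // trailing entry without a terminating ';'
	}
	return result;
}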
+#ifndef DUCKDB_EXTENSION_DIRECTORIES +#ifdef _WIN32 +#define DUCKDB_EXTENSION_DIRECTORIES "~\\.duckdb\\extensions" +#else +#define DUCKDB_EXTENSION_DIRECTORIES "~/.duckdb/extensions" +#endif +#endif + string dirs_string(DUCKDB_EXTENSION_DIRECTORIES); + + // Skip if empty + if (dirs_string.empty()) { + return default_folders; + } + + // Split the string by separator + auto directories = StringUtil::Split(dirs_string, ';'); + + for (auto &dir : directories) { + // Skip empty directories + if (dir.empty()) { + continue; + } + + default_folders.push_back(dir); + } + + return default_folders; } -string ExtensionHelper::GetExtensionDirectoryPath(ClientContext &context) { +vector ExtensionHelper::GetExtensionDirectoryPath(ClientContext &context) { auto &db = DatabaseInstance::GetDatabase(context); auto &fs = FileSystem::GetFileSystem(context); return GetExtensionDirectoryPath(db, fs); } -string ExtensionHelper::GetExtensionDirectoryPath(DatabaseInstance &db, FileSystem &fs) { - string extension_directory; +vector ExtensionHelper::GetExtensionDirectoryPath(DatabaseInstance &db, FileSystem &fs) { + vector extension_directories; auto &config = db.config; - if (!config.options.extension_directory.empty()) { // create the extension directory if not present - extension_directory = config.options.extension_directory; - // TODO this should probably live in the FileSystem - // convert random separators to platform-canonic - } else { // otherwise default to home - extension_directory = DefaultExtensionFolder(fs); + + if (!config.options.extension_directory.empty()) { + extension_directories.push_back(config.options.extension_directory); } - extension_directory = fs.ConvertSeparators(extension_directory); - // expand ~ in extension directory - extension_directory = fs.ExpandPath(extension_directory); + if (!config.options.extension_directories.empty()) { + // Add all configured extension directories + for (const auto &dir : config.options.extension_directories) { + extension_directories.push_back(dir); + } + } + if (extension_directories.empty()) { + // Add default extension directory if no custom directories configured + for (const auto &default_dir : ExtensionHelper::DefaultExtensionFolders(fs)) { + extension_directories.push_back(default_dir); + } + } + // Process all directories with common path operations auto path_components = PathComponents(); - for (auto &path_ele : path_components) { - extension_directory = fs.JoinPath(extension_directory, path_ele); + for (auto &extension_directory : extension_directories) { + // convert random separators to platform-canonic + extension_directory = fs.ConvertSeparators(extension_directory); + // expand ~ in extension directory + extension_directory = fs.ExpandPath(extension_directory); + + // Add path components (version and platform) + for (auto &path_ele : path_components) { + extension_directory = fs.JoinPath(extension_directory, path_ele); + } } - return extension_directory; + return extension_directories; } string ExtensionHelper::ExtensionDirectory(DatabaseInstance &db, FileSystem &fs) { #ifdef WASM_LOADABLE_EXTENSIONS throw PermissionException("ExtensionDirectory functionality is not supported in duckdb-wasm"); #endif - string extension_directory = GetExtensionDirectoryPath(db, fs); + auto extension_directories = GetExtensionDirectoryPath(db, fs); + // TODO: This should never be the case given the implementation of GetExtensionDirectoryPath + // should we still keep this check? 
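Further below, the extension_load.cpp hunk consumes a directory list like the one built above: directories are probed in order, the first one containing the extension file wins, and the first entry is reused for error reporting when nothing matches. A hypothetical standalone resolver showing that order:

#include <filesystem>
#include <string>
#include <vector>

static std::string ResolveExtension(const std::vector<std::string> &search_directories,
                                    const std::string &extension_name) {
	std::string first_candidate;
	for (const auto &directory : search_directories) {
		auto candidate = directory + "/" + extension_name + ".duckdb_extension";
		if (first_candidate.empty()) {
			first_candidate = candidate; // remember the primary location
		}
		if (std::filesystem::exists(candidate)) {
			return candidate; // first hit wins
		}
	}
	return first_candidate; // not found anywhere: report the primary location
}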
+ D_ASSERT(!extension_directories.empty()); + + string extension_directory = extension_directories[0]; // Use first/primary directory { if (!fs.DirectoryExists(extension_directory)) { + string home_directory = fs.GetHomeDirectory(); + if (extension_directory.rfind(home_directory, 0) == 0 && !fs.DirectoryExists(home_directory)) { + throw IOException("Can't find the home directory at '%s'\nSpecify a home directory using the SET " + "home_directory='/path/to/dir' option.", + home_directory); + } fs.CreateDirectoriesRecursive(extension_directory); } } @@ -246,12 +290,6 @@ static void WriteExtensionFiles(FileSystem &fs, const string &temp_path, const s auto metadata_file_path = local_extension_path + ".info"; WriteExtensionMetadataFileToDisk(fs, metadata_tmp_path, info); - // First remove the local extension we are about to replace - fs.TryRemoveFile(local_extension_path); - - // Then remove the old metadata file - fs.TryRemoveFile(metadata_file_path); - fs.MoveFile(metadata_tmp_path, metadata_file_path); fs.MoveFile(temp_path, local_extension_path); } diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp index 96e559ec0..0b88a2071 100644 --- a/src/duckdb/src/main/extension/extension_load.cpp +++ b/src/duckdb/src/main/extension/extension_load.cpp @@ -76,19 +76,19 @@ struct ExtensionAccess { load_state.has_error = true; load_state.error_data = error ? ErrorData(error) - : ErrorData(ExceptionType::UNKNOWN_TYPE, "Extension has indicated an error occured during " + : ErrorData(ExceptionType::UNKNOWN_TYPE, "Extension has indicated an error occurred during " "initialization, but did not set an error message."); } //! Called by the extension get a pointer to the database that is loading it - static duckdb_database *GetDatabase(duckdb_extension_info info) { + static duckdb_database GetDatabase(duckdb_extension_info info) { auto &load_state = DuckDBExtensionLoadState::Get(info); try { // Create the duckdb_database load_state.database_data = make_uniq(); load_state.database_data->database = make_shared_ptr(load_state.db); - return reinterpret_cast(load_state.database_data.get()); + return reinterpret_cast(load_state.database_data.get()); } catch (std::exception &ex) { load_state.has_error = true; load_state.error_data = ErrorData(ex); @@ -350,19 +350,53 @@ bool ExtensionHelper::TryInitialLoad(DatabaseInstance &db, FileSystem &fs, const filename = address; #else - string local_path = !db.config.options.extension_directory.empty() - ? 
db.config.options.extension_directory - : ExtensionHelper::DefaultExtensionFolder(fs); - - // convert random separators to platform-canonic - local_path = fs.ConvertSeparators(local_path); - // expand ~ in extension directory - local_path = fs.ExpandPath(local_path); - auto path_components = PathComponents(); - for (auto &path_ele : path_components) { - local_path = fs.JoinPath(local_path, path_ele); + // Local function to process local path + auto ComputeLocalExtensionPath = [&fs](const string &base_path, const string &extension_name) -> string { + // convert random separators to platform-canonic + string local_path = fs.ConvertSeparators(base_path); + // expand ~ in extension directory + local_path = fs.ExpandPath(local_path); + auto path_components = PathComponents(); + for (auto &path_ele : path_components) { + local_path = fs.JoinPath(local_path, path_ele); + } + return fs.JoinPath(local_path, extension_name + ".duckdb_extension"); + }; + + // Collect all directories to search for extensions + vector search_directories; + if (!db.config.options.extension_directory.empty()) { + search_directories.push_back(db.config.options.extension_directory); + } + + if (!db.config.options.extension_directories.empty()) { + // Add all configured extension directories + for (const auto &dir : db.config.options.extension_directories) { + search_directories.push_back(dir); + } + } + + // Add default extension directory if no custom directories configured + if (search_directories.empty()) { + for (const auto &path : ExtensionHelper::DefaultExtensionFolders(fs)) { + search_directories.push_back(path); + } + } + + // Try each directory in sequence until extension is found + bool found = false; + for (const auto &directory : search_directories) { + filename = ComputeLocalExtensionPath(directory, extension_name); + if (fs.FileExists(filename)) { + found = true; + break; + } + } + + // If not found in any directory, use the first directory for error reporting + if (!found) { + filename = ComputeLocalExtensionPath(search_directories[0], extension_name); } - filename = fs.JoinPath(local_path, extension_name + ".duckdb_extension"); #endif } else { direct_load = true; @@ -591,7 +625,7 @@ void ExtensionHelper::LoadExternalExtensionInternal(DatabaseInstance &db, FileSy if (result == false) { throw FatalException( "Extension '%s' failed to initialize but did not return an error. This indicates an " - "error in the extension: C API extensions should return a boolean `true` to indicate succesful " + "error in the extension: C API extensions should return a boolean `true` to indicate successful " "initialization. 
" "This means that the Extension may be partially initialized resulting in an inconsistent state of " "DuckDB.", diff --git a/src/duckdb/src/main/http/http_util.cpp b/src/duckdb/src/main/http/http_util.cpp index a51fb3e7f..fb5a9491f 100644 --- a/src/duckdb/src/main/http/http_util.cpp +++ b/src/duckdb/src/main/http/http_util.cpp @@ -130,9 +130,12 @@ BaseRequest::BaseRequest(RequestType type, const string &url, const HTTPHeaders class HTTPLibClient : public HTTPClient { public: HTTPLibClient(HTTPParams &http_params, const string &proto_host_port) { + client = make_uniq(proto_host_port); + Initialize(http_params); + } + void Initialize(HTTPParams &http_params) override { auto sec = static_cast(http_params.timeout); auto usec = static_cast(http_params.timeout_usec); - client = make_uniq(proto_host_port); client->set_follow_location(http_params.follow_location); client->set_keep_alive(http_params.keep_alive); client->set_write_timeout(sec, usec); @@ -228,12 +231,27 @@ unique_ptr HTTPUtil::SendRequest(BaseRequest &request, unique_ptr< std::function(void)> on_request([&]() { unique_ptr response; + + // When logging is enabled, we collect request timings + if (request.params.logger) { + request.have_request_timing = request.params.logger->ShouldLog(HTTPLogType::NAME, HTTPLogType::LEVEL); + } + try { + if (request.have_request_timing) { + request.request_start = Timestamp::GetCurrentTimestamp(); + } response = client->Request(request); } catch (...) { + if (request.have_request_timing) { + request.request_end = Timestamp::GetCurrentTimestamp(); + } LogRequest(request, nullptr); throw; } + if (request.have_request_timing) { + request.request_end = Timestamp::GetCurrentTimestamp(); + } LogRequest(request, response ? response.get() : nullptr); return response; }); @@ -367,7 +385,9 @@ HTTPUtil::RunRequestWithRetry(const std::function(void) try { response = on_request(); - response->url = request.url; + if (response) { + response->url = request.url; + } } catch (IOException &e) { exception_error = e.what(); caught_e = std::current_exception(); diff --git a/src/duckdb/src/main/materialized_query_result.cpp b/src/duckdb/src/main/materialized_query_result.cpp index d319d5686..cd8c006ba 100644 --- a/src/duckdb/src/main/materialized_query_result.cpp +++ b/src/duckdb/src/main/materialized_query_result.cpp @@ -84,11 +84,7 @@ unique_ptr MaterializedQueryResult::TakeCollection() { return std::move(collection); } -unique_ptr MaterializedQueryResult::Fetch() { - return FetchRaw(); -} - -unique_ptr MaterializedQueryResult::FetchRaw() { +unique_ptr MaterializedQueryResult::FetchInternal() { if (HasError()) { throw InvalidInputException("Attempting to fetch from an unsuccessful query result\nError: %s", GetError()); } diff --git a/src/duckdb/src/main/prepared_statement.cpp b/src/duckdb/src/main/prepared_statement.cpp index 34d12cadc..49ff9ac94 100644 --- a/src/duckdb/src/main/prepared_statement.cpp +++ b/src/duckdb/src/main/prepared_statement.cpp @@ -110,7 +110,10 @@ unique_ptr PreparedStatement::PendingQuery(case_insensitive_ } D_ASSERT(data); - parameters.allow_stream_result = allow_stream_result && data->properties.allow_stream_result; + parameters.query_parameters.output_type = + allow_stream_result && data->properties.output_type == QueryResultOutputType::ALLOW_STREAMING + ? 
QueryResultOutputType::ALLOW_STREAMING + : QueryResultOutputType::FORCE_MATERIALIZED; auto result = context->PendingQuery(query, data, parameters); // The result should not contain any reference to the 'vector parameters.parameters' return result; diff --git a/src/duckdb/src/main/prepared_statement_data.cpp b/src/duckdb/src/main/prepared_statement_data.cpp index 63069ceb1..7ef2edb41 100644 --- a/src/duckdb/src/main/prepared_statement_data.cpp +++ b/src/duckdb/src/main/prepared_statement_data.cpp @@ -71,7 +71,7 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, } } for (auto &it : properties.modified_databases) { - if (!CheckCatalogIdentity(context, it.first, it.second)) { + if (!CheckCatalogIdentity(context, it.first, it.second.identity)) { return true; } } diff --git a/src/duckdb/src/main/profiling_info.cpp b/src/duckdb/src/main/profiling_info.cpp index 8f744d51b..ff709b7de 100644 --- a/src/duckdb/src/main/profiling_info.cpp +++ b/src/duckdb/src/main/profiling_info.cpp @@ -1,7 +1,8 @@ #include "duckdb/main/profiling_info.hpp" #include "duckdb/common/enum_util.hpp" -#include "duckdb/main/query_profiler.hpp" +#include "duckdb/main/profiling_utils.hpp" +#include "duckdb/logging/log_manager.hpp" #include "yyjson.hpp" @@ -11,11 +12,9 @@ namespace duckdb { ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t depth) : settings(n_settings) { // Expand. - if (depth == 0) { - settings.insert(MetricsType::QUERY_NAME); - } else { - settings.insert(MetricsType::OPERATOR_NAME); - settings.insert(MetricsType::OPERATOR_TYPE); + if (depth > 0) { + settings.insert(MetricType::OPERATOR_NAME); + settings.insert(MetricType::OPERATOR_TYPE); } for (const auto &metric : settings) { Expand(expanded_settings, metric); @@ -23,12 +22,12 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t // Reduce. 
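A toy model of the expand-then-reduce pass in this constructor, with placeholder enum values instead of the real MetricType: requested metrics first pull in the metrics they are derived from, then the node depth decides whether root-level or operator-level metrics survive.

#include <set>

enum class Metric { QUERY_NAME, OPERATOR_NAME, CPU_TIME, OPERATOR_TIMING };

// Expand: a requested metric also enables the metric it is computed from.
static void ExpandMetric(std::set<Metric> &settings, Metric metric) {
	settings.insert(metric);
	if (metric == Metric::CPU_TIME) {
		settings.insert(Metric::OPERATOR_TIMING); // cumulative CPU time is summed from operator timings
	}
}

// Reduce: keep only the metrics that make sense at this node depth.
static std::set<Metric> SettingsForDepth(const std::set<Metric> &requested, int depth) {
	std::set<Metric> settings;
	for (auto metric : requested) {
		ExpandMetric(settings, metric);
	}
	if (depth == 0) {
		settings.erase(Metric::OPERATOR_NAME); // root node: drop per-operator metrics
	} else {
		settings.erase(Metric::QUERY_NAME); // operator node: drop root-only metrics
	}
	return settings;
}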
if (depth == 0) { - auto op_metrics = DefaultOperatorSettings(); + auto op_metrics = MetricsUtils::GetOperatorMetrics(); for (const auto metric : op_metrics) { settings.erase(metric); } } else { - auto root_metrics = DefaultRootSettings(); + auto root_metrics = MetricsUtils::GetRootScopeMetrics(); for (const auto metric : root_metrics) { settings.erase(metric); } @@ -36,37 +35,6 @@ ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t ResetMetrics(); } -profiler_settings_t ProfilingInfo::DefaultSettings() { - return {MetricsType::QUERY_NAME, - MetricsType::BLOCKED_THREAD_TIME, - MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, - MetricsType::CPU_TIME, - MetricsType::EXTRA_INFO, - MetricsType::CUMULATIVE_CARDINALITY, - MetricsType::OPERATOR_NAME, - MetricsType::OPERATOR_TYPE, - MetricsType::OPERATOR_CARDINALITY, - MetricsType::CUMULATIVE_ROWS_SCANNED, - MetricsType::OPERATOR_ROWS_SCANNED, - MetricsType::OPERATOR_TIMING, - MetricsType::RESULT_SET_SIZE, - MetricsType::LATENCY, - MetricsType::ROWS_RETURNED, - MetricsType::TOTAL_BYTES_READ, - MetricsType::TOTAL_BYTES_WRITTEN}; -} - -profiler_settings_t ProfilingInfo::DefaultRootSettings() { - return {MetricsType::QUERY_NAME, MetricsType::BLOCKED_THREAD_TIME, MetricsType::LATENCY, - MetricsType::ROWS_RETURNED}; -} - -profiler_settings_t ProfilingInfo::DefaultOperatorSettings() { - return {MetricsType::OPERATOR_CARDINALITY, MetricsType::OPERATOR_ROWS_SCANNED, MetricsType::OPERATOR_TIMING, - MetricsType::OPERATOR_NAME, MetricsType::OPERATOR_TYPE}; -} - void ProfilingInfo::ResetMetrics() { metrics.clear(); for (auto &metric : expanded_settings) { @@ -75,64 +43,32 @@ void ProfilingInfo::ResetMetrics() { continue; } - switch (metric) { - case MetricsType::QUERY_NAME: - metrics[metric] = Value::CreateValue(""); - break; - case MetricsType::LATENCY: - case MetricsType::BLOCKED_THREAD_TIME: - case MetricsType::CPU_TIME: - case MetricsType::OPERATOR_TIMING: - metrics[metric] = Value::CreateValue(0.0); - break; - case MetricsType::OPERATOR_NAME: - metrics[metric] = Value::CreateValue(""); - break; - case MetricsType::OPERATOR_TYPE: - metrics[metric] = Value::CreateValue(0); - break; - case MetricsType::ROWS_RETURNED: - case MetricsType::RESULT_SET_SIZE: - case MetricsType::CUMULATIVE_CARDINALITY: - case MetricsType::OPERATOR_CARDINALITY: - case MetricsType::CUMULATIVE_ROWS_SCANNED: - case MetricsType::OPERATOR_ROWS_SCANNED: - case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricsType::TOTAL_BYTES_READ: - case MetricsType::TOTAL_BYTES_WRITTEN: - metrics[metric] = Value::CreateValue(0); - break; - case MetricsType::EXTRA_INFO: - break; - default: - throw InternalException("MetricsType" + EnumUtil::ToString(metric) + "not implemented"); - } + ProfilingUtils::SetMetricToDefault(metrics, metric); } } -bool ProfilingInfo::Enabled(const profiler_settings_t &settings, const MetricsType metric) { +bool ProfilingInfo::Enabled(const profiler_settings_t &settings, const MetricType metric) { if (settings.find(metric) != settings.end()) { return true; } return false; } -void ProfilingInfo::Expand(profiler_settings_t &settings, const MetricsType metric) { +void ProfilingInfo::Expand(profiler_settings_t &settings, const MetricType metric) { settings.insert(metric); switch (metric) { - case MetricsType::CPU_TIME: - settings.insert(MetricsType::OPERATOR_TIMING); + case MetricType::CPU_TIME: + settings.insert(MetricType::OPERATOR_TIMING); return; - case 
MetricsType::CUMULATIVE_CARDINALITY: - settings.insert(MetricsType::OPERATOR_CARDINALITY); + case MetricType::CUMULATIVE_CARDINALITY: + settings.insert(MetricType::OPERATOR_CARDINALITY); return; - case MetricsType::CUMULATIVE_ROWS_SCANNED: - settings.insert(MetricsType::OPERATOR_ROWS_SCANNED); + case MetricType::CUMULATIVE_ROWS_SCANNED: + settings.insert(MetricType::OPERATOR_ROWS_SCANNED); return; - case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricsType::ALL_OPTIMIZERS: { + case MetricType::CUMULATIVE_OPTIMIZER_TIMING: + case MetricType::ALL_OPTIMIZERS: { auto optimizer_metrics = MetricsUtils::GetOptimizerMetrics(); for (const auto optimizer_metric : optimizer_metrics) { settings.insert(optimizer_metric); @@ -144,52 +80,58 @@ void ProfilingInfo::Expand(profiler_settings_t &settings, const MetricsType metr } } -string ProfilingInfo::GetMetricAsString(const MetricsType metric) const { +string ProfilingInfo::GetMetricAsString(const MetricType metric) const { if (!Enabled(settings, metric)) { throw InternalException("Metric %s not enabled", EnumUtil::ToString(metric)); } - if (metric == MetricsType::EXTRA_INFO) { - string result; - for (auto &it : extra_info) { - if (!result.empty()) { - result += ", "; - } - result += StringUtil::Format("%s: %s", it.first, it.second); - } - return "\"" + result + "\""; - } - // The metric cannot be NULL and must be initialized. D_ASSERT(!metrics.at(metric).IsNull()); - if (metric == MetricsType::OPERATOR_TYPE) { - auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); + if (metric == MetricType::OPERATOR_TYPE) { + const auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); return EnumUtil::ToString(type); } return metrics.at(metric).ToString(); } +void ProfilingInfo::WriteMetricsToLog(ClientContext &context) { + auto &logger = Logger::Get(context); + if (logger.ShouldLog(MetricsLogType::NAME, MetricsLogType::LEVEL)) { + for (auto &metric : settings) { + logger.WriteLog(MetricsLogType::NAME, MetricsLogType::LEVEL, + MetricsLogType::ConstructLogMessage(metric, metrics[metric])); + } + } +} + void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest) { for (auto &metric : settings) { auto metric_str = StringUtil::Lower(EnumUtil::ToString(metric)); auto key_val = yyjson_mut_strcpy(doc, metric_str.c_str()); auto key_ptr = yyjson_mut_get_str(key_val); - if (metric == MetricsType::EXTRA_INFO) { + if (metric == MetricType::EXTRA_INFO) { auto extra_info_obj = yyjson_mut_obj(doc); - for (auto &it : extra_info) { - auto &key = it.first; - auto &value = it.second; - auto splits = StringUtil::Split(value, "\n"); + auto extra_info = metrics.at(metric); + auto children = MapValue::GetChildren(extra_info); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + + auto key_mut = unsafe_yyjson_mut_strncpy(doc, key.c_str(), key.size()); + auto value_mut = unsafe_yyjson_mut_strncpy(doc, value.c_str(), value.size()); + + auto splits = StringUtil::Split(value_mut, "\n"); if (splits.size() > 1) { auto list_items = yyjson_mut_arr(doc); for (auto &split : splits) { yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str()); } - yyjson_mut_obj_add_val(doc, extra_info_obj, key.c_str(), list_items); + yyjson_mut_obj_add_val(doc, extra_info_obj, key_mut, list_items); } else { - yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key.c_str(), value.c_str()); + yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key_mut, 
value_mut); } } yyjson_mut_obj_add_val(doc, dest, key_ptr, extra_info_obj); @@ -204,38 +146,7 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest continue; } - switch (metric) { - case MetricsType::QUERY_NAME: - case MetricsType::OPERATOR_NAME: - yyjson_mut_obj_add_strcpy(doc, dest, key_ptr, metrics[metric].GetValue().c_str()); - break; - case MetricsType::LATENCY: - case MetricsType::BLOCKED_THREAD_TIME: - case MetricsType::CPU_TIME: - case MetricsType::OPERATOR_TIMING: { - yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[metric].GetValue()); - break; - } - case MetricsType::OPERATOR_TYPE: { - yyjson_mut_obj_add_strcpy(doc, dest, key_ptr, GetMetricAsString(metric).c_str()); - break; - } - case MetricsType::ROWS_RETURNED: - case MetricsType::RESULT_SET_SIZE: - case MetricsType::CUMULATIVE_CARDINALITY: - case MetricsType::OPERATOR_CARDINALITY: - case MetricsType::CUMULATIVE_ROWS_SCANNED: - case MetricsType::OPERATOR_ROWS_SCANNED: - case MetricsType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricsType::TOTAL_BYTES_READ: - case MetricsType::TOTAL_BYTES_WRITTEN: { - yyjson_mut_obj_add_uint(doc, dest, key_ptr, metrics[metric].GetValue()); - break; - } - default: - throw NotImplementedException("MetricsType %s not implemented", EnumUtil::ToString(metric)); - } + ProfilingUtils::MetricToJson(doc, dest, key_ptr, metrics, metric); } } diff --git a/src/duckdb/src/main/profiling_utils.cpp b/src/duckdb/src/main/profiling_utils.cpp new file mode 100644 index 000000000..20597c15f --- /dev/null +++ b/src/duckdb/src/main/profiling_utils.cpp @@ -0,0 +1,215 @@ +// This file is automatically generated by scripts/generate_metric_enums.py +// Do not edit this file manually, your changes will be overwritten + +#include "duckdb/main/profiling_utils.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/main/profiling_node.hpp" +#include "duckdb/main/query_profiler.hpp" + +#include "yyjson.hpp" + +using namespace duckdb_yyjson; // NOLINT + +namespace duckdb { + +static string OperatorToString(const Value &val) { + const auto type = static_cast(val.GetValue()); + return EnumUtil::ToString(type); +} + +template +static void AggregateMetric(ProfilingNode &node, MetricType aggregated_metric, MetricType child_metric, const std::function &update_fun) { + auto &info = node.GetProfilingInfo(); + info.metrics[aggregated_metric] = info.metrics[child_metric]; + + for (idx_t i = 0; i < node.GetChildCount(); i++) { + auto child = node.GetChild(i); + AggregateMetric(*child, aggregated_metric, child_metric, update_fun); + + auto &child_info = child->GetProfilingInfo(); + auto value = child_info.GetMetricValue(aggregated_metric); + info.MetricUpdate(aggregated_metric, value, update_fun); + } +} + +template +static void GetCumulativeMetric(ProfilingNode &node, MetricType cumulative_metric, MetricType child_metric) { + AggregateMetric( + node, cumulative_metric, child_metric, + [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { return old_value + new_value; }); +} + +static Value GetCumulativeOptimizers(ProfilingNode &node) { + auto &metrics = node.GetProfilingInfo().metrics; + double count = 0; + for (auto &metric : metrics) { + if (MetricsUtils::IsOptimizerMetric(metric.first)) { + count += metric.second.GetValue(); + } + } + return Value::CreateValue(count); +} + +void ProfilingUtils::SetMetricToDefault(profiler_metrics_t &metrics, const MetricType &type) { + switch(type) { + case MetricType::ALL_OPTIMIZERS: + case 
MetricType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricType::ATTACH_REPLAY_WAL_LATENCY: + case MetricType::BLOCKED_THREAD_TIME: + case MetricType::CHECKPOINT_LATENCY: + case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: + case MetricType::CPU_TIME: + case MetricType::CUMULATIVE_OPTIMIZER_TIMING: + case MetricType::LATENCY: + case MetricType::OPERATOR_TIMING: + case MetricType::PHYSICAL_PLANNER: + case MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: + case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: + case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: + case MetricType::PLANNER: + case MetricType::PLANNER_BINDING: + case MetricType::WAITING_TO_ATTACH_LATENCY: + case MetricType::WRITE_TO_WAL_LATENCY: + metrics[type] = Value::CreateValue(0.0); + break; + case MetricType::CUMULATIVE_CARDINALITY: + case MetricType::CUMULATIVE_ROWS_SCANNED: + case MetricType::OPERATOR_CARDINALITY: + case MetricType::OPERATOR_ROWS_SCANNED: + case MetricType::RESULT_SET_SIZE: + case MetricType::ROWS_RETURNED: + case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: + case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricType::TOTAL_BYTES_READ: + case MetricType::TOTAL_BYTES_WRITTEN: + case MetricType::TOTAL_MEMORY_ALLOCATED: + case MetricType::WAL_REPLAY_ENTRY_COUNT: + metrics[type] = Value::CreateValue(0); + break; + case MetricType::EXTRA_INFO: + metrics[type] = Value::MAP(InsertionOrderPreservingMap()); + break; + case MetricType::OPERATOR_NAME: + case MetricType::QUERY_NAME: + metrics[type] = Value::CreateValue(""); + break; + case MetricType::OPERATOR_TYPE: + metrics[type] = Value::CreateValue(0); + break; + default: + throw InternalException("Unknown metric type %s", EnumUtil::ToString(type)); + } +} + +void ProfilingUtils::MetricToJson(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *dest, const char *key_ptr, profiler_metrics_t &metrics, const MetricType &type) { + switch(type) { + case MetricType::ALL_OPTIMIZERS: + case MetricType::ATTACH_LOAD_STORAGE_LATENCY: + case MetricType::ATTACH_REPLAY_WAL_LATENCY: + case MetricType::BLOCKED_THREAD_TIME: + case MetricType::CHECKPOINT_LATENCY: + case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: + case MetricType::CPU_TIME: + case MetricType::CUMULATIVE_OPTIMIZER_TIMING: + case MetricType::LATENCY: + case MetricType::OPERATOR_TIMING: + case MetricType::PHYSICAL_PLANNER: + case MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: + case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: + case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: + case MetricType::PLANNER: + case MetricType::PLANNER_BINDING: + case MetricType::WAITING_TO_ATTACH_LATENCY: + case MetricType::WRITE_TO_WAL_LATENCY: + yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[type].GetValue()); + break; + case MetricType::CUMULATIVE_CARDINALITY: + case MetricType::CUMULATIVE_ROWS_SCANNED: + case MetricType::OPERATOR_CARDINALITY: + case MetricType::OPERATOR_ROWS_SCANNED: + case MetricType::RESULT_SET_SIZE: + case MetricType::ROWS_RETURNED: + case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: + case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: + case MetricType::TOTAL_BYTES_READ: + case MetricType::TOTAL_BYTES_WRITTEN: + case MetricType::TOTAL_MEMORY_ALLOCATED: + case MetricType::WAL_REPLAY_ENTRY_COUNT: + yyjson_mut_obj_add_uint(doc, dest, key_ptr, metrics[type].GetValue()); + break; + case MetricType::EXTRA_INFO: + break; + case MetricType::OPERATOR_NAME: + case MetricType::QUERY_NAME: + yyjson_mut_obj_add_strcpy(doc, dest, key_ptr, metrics[type].GetValue().c_str()); + break; + case MetricType::OPERATOR_TYPE: + yyjson_mut_obj_add_strcpy(doc, 
dest, key_ptr, OperatorToString(metrics[type]).c_str()); + break; + default: + throw InternalException("Unknown metric type %s", EnumUtil::ToString(type)); + } +} + +void ProfilingUtils::CollectMetrics(const MetricType &type, QueryMetrics &query_metrics, Value &metric, ProfilingNode &node, ProfilingInfo &child_info) { + switch(type) { + case MetricType::CPU_TIME: + GetCumulativeMetric(node, MetricType::CPU_TIME, MetricType::OPERATOR_TIMING); + break; + case MetricType::CUMULATIVE_CARDINALITY: + GetCumulativeMetric(node, MetricType::CUMULATIVE_CARDINALITY, MetricType::OPERATOR_CARDINALITY); + break; + case MetricType::CUMULATIVE_ROWS_SCANNED: + GetCumulativeMetric(node, MetricType::CUMULATIVE_ROWS_SCANNED, MetricType::OPERATOR_ROWS_SCANNED); + break; + case MetricType::ATTACH_LOAD_STORAGE_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::ATTACH_LOAD_STORAGE_LATENCY)); + break; + case MetricType::ATTACH_REPLAY_WAL_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::ATTACH_REPLAY_WAL_LATENCY)); + break; + case MetricType::CHECKPOINT_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::CHECKPOINT_LATENCY)); + break; + case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::COMMIT_LOCAL_STORAGE_LATENCY)); + break; + case MetricType::LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::LATENCY)); + break; + case MetricType::WAITING_TO_ATTACH_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::WAITING_TO_ATTACH_LATENCY)); + break; + case MetricType::WRITE_TO_WAL_LATENCY: + metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::WRITE_TO_WAL_LATENCY)); + break; + case MetricType::QUERY_NAME: + metric = query_metrics.query_name; + break; + case MetricType::TOTAL_BYTES_READ: + metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_READ)); + break; + case MetricType::TOTAL_BYTES_WRITTEN: + metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_WRITTEN)); + break; + case MetricType::TOTAL_MEMORY_ALLOCATED: + metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_MEMORY_ALLOCATED)); + break; + case MetricType::WAL_REPLAY_ENTRY_COUNT: + metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::WAL_REPLAY_ENTRY_COUNT)); + break; + case MetricType::RESULT_SET_SIZE: + metric = child_info.metrics[MetricType::RESULT_SET_SIZE]; + break; + case MetricType::ROWS_RETURNED: + metric = child_info.metrics[MetricType::OPERATOR_CARDINALITY]; + break; + case MetricType::CUMULATIVE_OPTIMIZER_TIMING: + metric = GetCumulativeOptimizers(node); + break; + default: + return; + } +} + +} diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp index 4c9c9328a..7d3092063 100644 --- a/src/duckdb/src/main/query_profiler.cpp +++ b/src/duckdb/src/main/query_profiler.cpp @@ -13,9 +13,11 @@ #include "duckdb/main/client_config.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/main/profiling_utils.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "yyjson.hpp" +#include "yyjson_utils.hpp" #include #include @@ -52,6 +54,8 @@ ProfilerPrintFormat QueryProfiler::GetPrintFormat(ExplainFormat format) const { return ProfilerPrintFormat::HTML; case ExplainFormat::GRAPHVIZ: return 
ProfilerPrintFormat::GRAPHVIZ; + case ExplainFormat::MERMAID: + return ProfilerPrintFormat::MERMAID; default: throw NotImplementedException("No mapping from ExplainFormat::%s to ProfilerPrintFormat", EnumUtil::ToString(format)); @@ -69,6 +73,8 @@ ExplainFormat QueryProfiler::GetExplainFormat(ProfilerPrintFormat format) const return ExplainFormat::HTML; case ProfilerPrintFormat::GRAPHVIZ: return ExplainFormat::GRAPHVIZ; + case ProfilerPrintFormat::MERMAID: + return ExplainFormat::MERMAID; case ProfilerPrintFormat::NO_OUTPUT: throw InternalException("Should not attempt to get ExplainFormat for ProfilerPrintFormat::NO_OUTPUT"); default: @@ -92,8 +98,8 @@ QueryProfiler &QueryProfiler::Get(ClientContext &context) { void QueryProfiler::Start(const string &query) { Reset(); running = true; - query_metrics.query = query; - query_metrics.latency.Start(); + query_metrics.query_name = query; + query_metrics.latency_timer = make_uniq(StartTimer(MetricType::LATENCY)); } void QueryProfiler::Reset() { @@ -102,9 +108,7 @@ void QueryProfiler::Reset() { phase_timings.clear(); phase_stack.clear(); running = false; - query_metrics.query = ""; - query_metrics.total_bytes_read = 0; - query_metrics.total_bytes_written = 0; + query_metrics.Reset(); } void QueryProfiler::StartQuery(const string &query, bool is_explain_analyze_p, bool start_at_optimizer) { @@ -121,7 +125,7 @@ void QueryProfiler::StartQuery(const string &query, bool is_explain_analyze_p, b } if (running) { // Called while already running: this should only happen when we print optimizer output - D_ASSERT(PrintOptimizerOutput()); + // D_ASSERT(PrintOptimizerOutput()); return; } Start(query); @@ -177,13 +181,12 @@ void QueryProfiler::Finalize(ProfilingNode &node) { Finalize(*child); auto &info = node.GetProfilingInfo(); - auto type = PhysicalOperatorType(info.GetMetricValue(MetricsType::OPERATOR_TYPE)); + auto type = PhysicalOperatorType(info.GetMetricValue(MetricType::OPERATOR_TYPE)); if (type == PhysicalOperatorType::UNION && - info.Enabled(info.expanded_settings, MetricsType::OPERATOR_CARDINALITY)) { - + info.Enabled(info.expanded_settings, MetricType::OPERATOR_CARDINALITY)) { auto &child_info = child->GetProfilingInfo(); - auto value = child_info.metrics[MetricsType::OPERATOR_CARDINALITY].GetValue(); - info.MetricSum(MetricsType::OPERATOR_CARDINALITY, value); + auto value = child_info.metrics[MetricType::OPERATOR_CARDINALITY].GetValue(); + info.MetricSum(MetricType::OPERATOR_CARDINALITY, value); } } } @@ -192,50 +195,16 @@ void QueryProfiler::StartExplainAnalyze() { is_explain_analyze = true; } -template -static void AggregateMetric(ProfilingNode &node, MetricsType aggregated_metric, MetricsType child_metric, - const std::function &update_fun) { - auto &info = node.GetProfilingInfo(); - info.metrics[aggregated_metric] = info.metrics[child_metric]; - - for (idx_t i = 0; i < node.GetChildCount(); i++) { - auto child = node.GetChild(i); - AggregateMetric(*child, aggregated_metric, child_metric, update_fun); - - auto &child_info = child->GetProfilingInfo(); - auto value = child_info.GetMetricValue(aggregated_metric); - info.MetricUpdate(aggregated_metric, value, update_fun); - } -} - -template -static void GetCumulativeMetric(ProfilingNode &node, MetricsType cumulative_metric, MetricsType child_metric) { - AggregateMetric( - node, cumulative_metric, child_metric, - [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { return old_value + new_value; }); -} - -Value GetCumulativeOptimizers(ProfilingNode &node) { - auto &metrics = 
node.GetProfilingInfo().metrics;
-	double count = 0;
-	for (auto &metric : metrics) {
-		if (MetricsUtils::IsOptimizerMetric(metric.first)) {
-			count += metric.second.GetValue();
-		}
-	}
-	return Value::CreateValue(count);
-}
-
 void QueryProfiler::EndQuery() {
 	unique_lock guard(lock);
 	if (!IsEnabled() || !running) {
 		return;
 	}
-	query_metrics.latency.End();
+	query_metrics.latency_timer->EndTimer();
 	if (root) {
 		auto &info = root->GetProfilingInfo();
-		if (info.Enabled(info.expanded_settings, MetricsType::OPERATOR_CARDINALITY)) {
+		if (info.Enabled(info.expanded_settings, MetricType::OPERATOR_CARDINALITY)) {
 			Finalize(*root->GetChild(0));
 		}
 	}
@@ -249,42 +218,17 @@ void QueryProfiler::EndQuery() {
 		auto &info = root->GetProfilingInfo();
 		info = ProfilingInfo(ClientConfig::GetConfig(context).profiler_settings);
 		auto &child_info = root->children[0]->GetProfilingInfo();
-		info.metrics[MetricsType::QUERY_NAME] = query_metrics.query;
-		auto &settings = info.expanded_settings;
+		const auto &settings = info.expanded_settings;
 		for (const auto &global_info_entry : query_metrics.query_global_info.metrics) {
 			info.metrics[global_info_entry.first] = global_info_entry.second;
 		}
-		if (info.Enabled(settings, MetricsType::LATENCY)) {
-			info.metrics[MetricsType::LATENCY] = query_metrics.latency.Elapsed();
-		}
-		if (info.Enabled(settings, MetricsType::TOTAL_BYTES_READ)) {
-			info.metrics[MetricsType::TOTAL_BYTES_READ] = Value::UBIGINT(query_metrics.total_bytes_read);
-		}
-		if (info.Enabled(settings, MetricsType::TOTAL_BYTES_WRITTEN)) {
-			info.metrics[MetricsType::TOTAL_BYTES_WRITTEN] = Value::UBIGINT(query_metrics.total_bytes_written);
-		}
-		if (info.Enabled(settings, MetricsType::ROWS_RETURNED)) {
-			info.metrics[MetricsType::ROWS_RETURNED] = child_info.metrics[MetricsType::OPERATOR_CARDINALITY];
-		}
-		if (info.Enabled(settings, MetricsType::CPU_TIME)) {
-			GetCumulativeMetric(*root, MetricsType::CPU_TIME, MetricsType::OPERATOR_TIMING);
-		}
-		if (info.Enabled(settings, MetricsType::CUMULATIVE_CARDINALITY)) {
-			GetCumulativeMetric(*root, MetricsType::CUMULATIVE_CARDINALITY,
-			                    MetricsType::OPERATOR_CARDINALITY);
-		}
-		if (info.Enabled(settings, MetricsType::CUMULATIVE_ROWS_SCANNED)) {
-			GetCumulativeMetric(*root, MetricsType::CUMULATIVE_ROWS_SCANNED,
-			                    MetricsType::OPERATOR_ROWS_SCANNED);
-		}
-		if (info.Enabled(settings, MetricsType::RESULT_SET_SIZE)) {
-			info.metrics[MetricsType::RESULT_SET_SIZE] = child_info.metrics[MetricsType::RESULT_SET_SIZE];
-		}
 		MoveOptimizerPhasesToRoot();
-		if (info.Enabled(settings, MetricsType::CUMULATIVE_OPTIMIZER_TIMING)) {
-			info.metrics.at(MetricsType::CUMULATIVE_OPTIMIZER_TIMING) = GetCumulativeOptimizers(*root);
+		for (auto &metric : info.metrics) {
+			if (info.Enabled(settings, metric.first)) {
+				ProfilingUtils::CollectMetrics(metric.first, query_metrics, metric.second, *root, child_info);
+			}
 		}
 	}
@@ -297,6 +241,9 @@ void QueryProfiler::EndQuery() {
 
 	guard.unlock();
 
+	// To log is inexpensive, whether to log or not depends on whether logging is active
+	ToLog();
+
 	if (emit_output) {
 		string tree = ToString();
 		auto save_location = GetSaveLocation();
@@ -310,16 +257,22 @@
 	}
 }
 
-void QueryProfiler::AddBytesRead(const idx_t nr_bytes) {
+void QueryProfiler::AddToCounter(const MetricType type, const idx_t amount) {
 	if (IsEnabled()) {
-		query_metrics.total_bytes_read += nr_bytes;
+		query_metrics.UpdateMetric(type, amount);
 	}
 }
 
-void QueryProfiler::AddBytesWritten(const idx_t nr_bytes) {
-	if (IsEnabled()) {
-		query_metrics.total_bytes_written += nr_bytes;
-	}
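
Illustrative sketch, not part of the patch: the hunk below replaces the dedicated AddBytesRead/AddBytesWritten entry points with a single counter-based API (AddToCounter plus typed getters). The helper RecordBytesRead, its include set, and its placement are assumptions made only to show the intended call shape.

#include "duckdb/main/client_context.hpp"
#include "duckdb/main/query_profiler.hpp"

namespace duckdb {

// Hypothetical call site: record bytes read on behalf of the currently running query.
static void RecordBytesRead(ClientContext &context, idx_t bytes_read) {
	auto &profiler = QueryProfiler::Get(context);
	// AddToCounter only updates the metric while profiling is enabled.
	profiler.AddToCounter(MetricType::TOTAL_BYTES_READ, bytes_read);
}

} // namespace duckdb
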
+idx_t QueryProfiler::GetBytesRead() const {
+	return query_metrics.GetMetricsIndex(MetricType::TOTAL_BYTES_READ);
+}
+
+idx_t QueryProfiler::GetBytesWritten() const {
+	return query_metrics.GetMetricsIndex(MetricType::TOTAL_BYTES_WRITTEN);
+}
+
+ActiveTimer QueryProfiler::StartTimer(const MetricType type) {
+	return ActiveTimer(query_metrics, type, IsEnabled());
 }
 
 string QueryProfiler::ToString(ExplainFormat explain_format) const {
@@ -339,19 +292,16 @@ string QueryProfiler::ToString(ProfilerPrintFormat format) const {
 	case ProfilerPrintFormat::NO_OUTPUT:
 		return "";
 	case ProfilerPrintFormat::HTML:
-	case ProfilerPrintFormat::GRAPHVIZ: {
+	case ProfilerPrintFormat::GRAPHVIZ:
+	case ProfilerPrintFormat::MERMAID: {
 		lock_guard guard(lock);
 		// checking the tree to ensure the query is really empty
 		// the query string is empty when a logical plan is deserialized
-		if (query_metrics.query.empty() && !root) {
+		if (query_metrics.query_name.empty() && !root) {
 			return "";
 		}
 		auto renderer = TreeRenderer::CreateRenderer(GetExplainFormat(format));
-		duckdb::stringstream str;
-		auto &info = root->GetProfilingInfo();
-		if (info.Enabled(info.expanded_settings, MetricsType::OPERATOR_TIMING)) {
-			info.metrics[MetricsType::OPERATOR_TIMING] = query_metrics.latency.Elapsed();
-		}
+		stringstream str;
 		renderer->Render(*root, str);
 		return str.str();
 	}
@@ -360,7 +310,7 @@ string QueryProfiler::ToString(ProfilerPrintFormat format) const {
 	}
 }
 
-void QueryProfiler::StartPhase(MetricsType phase_metric) {
+void QueryProfiler::StartPhase(MetricType phase_metric) {
 	lock_guard guard(lock);
 	if (!IsEnabled() || !running) {
 		return;
@@ -404,7 +354,7 @@ OperatorProfiler::OperatorProfiler(ClientContext &context) : context(context) {
 	}
 
 	// Reduce.
-	auto root_metrics = ProfilingInfo::DefaultRootSettings();
+	auto root_metrics = MetricsUtils::GetRootScopeMetrics();
 	for (const auto metric : root_metrics) {
 		settings.erase(metric);
 	}
@@ -420,7 +370,7 @@ void OperatorProfiler::StartOperator(optional_ptr phys_o
 	active_operator = phys_op;
 
 	if (!settings.empty()) {
-		if (ProfilingInfo::Enabled(settings, MetricsType::EXTRA_INFO)) {
+		if (ProfilingInfo::Enabled(settings, MetricType::EXTRA_INFO)) {
 			if (!OperatorInfoIsInitialized(*active_operator)) {
 				// first time calling into this operator - fetch the info
 				auto &info = GetOperatorInfo(*active_operator);
@@ -430,7 +380,7 @@ void OperatorProfiler::StartOperator(optional_ptr phys_o
 	}
 
 	// Start the timing of the current operator.
- if (ProfilingInfo::Enabled(settings, MetricsType::OPERATOR_TIMING)) { + if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_TIMING)) { op.Start(); } } @@ -446,22 +396,22 @@ void OperatorProfiler::EndOperator(optional_ptr chunk) { if (!settings.empty()) { auto &info = GetOperatorInfo(*active_operator); - if (ProfilingInfo::Enabled(settings, MetricsType::OPERATOR_TIMING)) { + if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_TIMING)) { op.End(); info.AddTime(op.Elapsed()); } - if (ProfilingInfo::Enabled(settings, MetricsType::OPERATOR_CARDINALITY) && chunk) { + if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_CARDINALITY) && chunk) { info.AddReturnedElements(chunk->size()); } - if (ProfilingInfo::Enabled(settings, MetricsType::RESULT_SET_SIZE) && chunk) { + if (ProfilingInfo::Enabled(settings, MetricType::RESULT_SET_SIZE) && chunk) { auto result_set_size = chunk->GetAllocationSize(); info.AddResultSetSize(result_set_size); } - if (ProfilingInfo::Enabled(settings, MetricsType::SYSTEM_PEAK_BUFFER_MEMORY)) { + if (ProfilingInfo::Enabled(settings, MetricType::SYSTEM_PEAK_BUFFER_MEMORY)) { auto used_memory = BufferManager::GetBufferManager(context).GetBufferPool().GetUsedMemory(false); info.UpdateSystemPeakBufferManagerMemory(used_memory); } - if (ProfilingInfo::Enabled(settings, MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { + if (ProfilingInfo::Enabled(settings, MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { auto used_swap = BufferManager::GetBufferManager(context).GetUsedSwap(); info.UpdateSystemPeakTempDirectorySize(used_swap); } @@ -477,7 +427,7 @@ void OperatorProfiler::FinishSource(GlobalSourceState &gstate, LocalSourceState throw InternalException("OperatorProfiler: Attempting to call FinishSource while no operator is active"); } if (!settings.empty()) { - if (ProfilingInfo::Enabled(settings, MetricsType::EXTRA_INFO)) { + if (ProfilingInfo::Enabled(settings, MetricType::EXTRA_INFO)) { // we're emitting extra info - get the extra source info auto &info = GetOperatorInfo(*active_operator); auto extra_info = active_operator->ExtraSourceParams(gstate, lstate); @@ -534,13 +484,13 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { auto &tree_node = entry->second.get(); auto &info = tree_node.GetProfilingInfo(); - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::OPERATOR_TIMING)) { - info.MetricSum(MetricsType::OPERATOR_TIMING, node.second.time); + if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_TIMING)) { + info.MetricSum(MetricType::OPERATOR_TIMING, node.second.time); } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::OPERATOR_CARDINALITY)) { - info.MetricSum(MetricsType::OPERATOR_CARDINALITY, node.second.elements_returned); + if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_CARDINALITY)) { + info.MetricSum(MetricType::OPERATOR_CARDINALITY, node.second.elements_returned); } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::OPERATOR_ROWS_SCANNED)) { + if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_ROWS_SCANNED)) { if (op.type == PhysicalOperatorType::TABLE_SCAN) { auto &scan_op = op.Cast(); auto &bind_data = scan_op.bind_data; @@ -548,38 +498,38 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { if (bind_data && scan_op.function.cardinality) { auto cardinality = scan_op.function.cardinality(context, &(*bind_data)); if (cardinality && cardinality->has_estimated_cardinality) { - info.MetricSum(MetricsType::OPERATOR_ROWS_SCANNED, cardinality->estimated_cardinality); + 
info.MetricSum(MetricType::OPERATOR_ROWS_SCANNED, cardinality->estimated_cardinality); } } } } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::RESULT_SET_SIZE)) { - info.MetricSum(MetricsType::RESULT_SET_SIZE, node.second.result_set_size); + if (ProfilingInfo::Enabled(profiler.settings, MetricType::RESULT_SET_SIZE)) { + info.MetricSum(MetricType::RESULT_SET_SIZE, node.second.result_set_size); } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = node.second.extra_info; + if (ProfilingInfo::Enabled(profiler.settings, MetricType::EXTRA_INFO)) { + info.metrics[MetricType::EXTRA_INFO] = Value::MAP(node.second.extra_info); } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::SYSTEM_PEAK_BUFFER_MEMORY)) { - query_metrics.query_global_info.MetricMax(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, + if (ProfilingInfo::Enabled(profiler.settings, MetricType::SYSTEM_PEAK_BUFFER_MEMORY)) { + query_metrics.query_global_info.MetricMax(MetricType::SYSTEM_PEAK_BUFFER_MEMORY, node.second.system_peak_buffer_manager_memory); } - if (ProfilingInfo::Enabled(profiler.settings, MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { - query_metrics.query_global_info.MetricMax(MetricsType::SYSTEM_PEAK_TEMP_DIR_SIZE, + if (ProfilingInfo::Enabled(profiler.settings, MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { + query_metrics.query_global_info.MetricMax(MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, node.second.system_peak_temp_directory_size); } } profiler.operator_infos.clear(); } -void QueryProfiler::SetInfo(const double &blocked_thread_time) { +void QueryProfiler::SetBlockedTime(const double &blocked_thread_time) { lock_guard guard(lock); if (!IsEnabled() || !running) { return; } auto &info = root->GetProfilingInfo(); - if (info.Enabled(info.expanded_settings, MetricsType::BLOCKED_THREAD_TIME)) { - query_metrics.query_global_info.metrics[MetricsType::BLOCKED_THREAD_TIME] = blocked_thread_time; + if (info.Enabled(info.expanded_settings, MetricType::BLOCKED_THREAD_TIME)) { + query_metrics.query_global_info.metrics[MetricType::BLOCKED_THREAD_TIME] = blocked_thread_time; } } @@ -656,15 +606,15 @@ void PrintPhaseTimingsToStream(std::ostream &ss, const ProfilingInfo &info, idx_ optimizer_timings[EnumUtil::ToString(entry.first).substr(10)] = entry.second.GetValue(); } else if (MetricsUtils::IsPhaseTimingMetric(entry.first)) { switch (entry.first) { - case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: + case MetricType::CUMULATIVE_OPTIMIZER_TIMING: continue; - case MetricsType::ALL_OPTIMIZERS: + case MetricType::ALL_OPTIMIZERS: optimizer_head = {"Optimizer", entry.second.GetValue()}; break; - case MetricsType::PHYSICAL_PLANNER: + case MetricType::PHYSICAL_PLANNER: physical_planner_head = {"Physical Planner", entry.second.GetValue()}; break; - case MetricsType::PLANNER: + case MetricType::PLANNER: planner_head = {"Planner", entry.second.GetValue()}; break; default: @@ -672,9 +622,9 @@ void PrintPhaseTimingsToStream(std::ostream &ss, const ProfilingInfo &info, idx_ } auto metric = EnumUtil::ToString(entry.first); - if (StringUtil::StartsWith(metric, "PHYSICAL_PLANNER") && entry.first != MetricsType::PHYSICAL_PLANNER) { + if (StringUtil::StartsWith(metric, "PHYSICAL_PLANNER") && entry.first != MetricType::PHYSICAL_PLANNER) { physical_planner_timings[metric.substr(17)] = entry.second.GetValue(); - } else if (StringUtil::StartsWith(metric, "PLANNER") && entry.first != MetricsType::PLANNER) { + } else if (StringUtil::StartsWith(metric, "PLANNER") && entry.first != MetricType::PLANNER) { 
planner_timings[metric.substr(8)] = entry.second.GetValue(); } } @@ -687,16 +637,23 @@ void PrintPhaseTimingsToStream(std::ostream &ss, const ProfilingInfo &info, idx_ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { lock_guard guard(lock); + + bool show_query_name = false; + if (root) { + auto &info = root->GetProfilingInfo(); + auto &settings = info.expanded_settings; + show_query_name = info.Enabled(settings, MetricType::QUERY_NAME); + } ss << "┌─────────────────────────────────────┐\n"; ss << "│┌───────────────────────────────────┐│\n"; ss << "││ Query Profiling Information ││\n"; ss << "│└───────────────────────────────────┘│\n"; ss << "└─────────────────────────────────────┘\n"; - ss << StringUtil::Replace(query_metrics.query, "\n", " ") + "\n"; + ss << (show_query_name ? StringUtil::Replace(query_metrics.query_name, "\n", " ") : "") + "\n"; // checking the tree to ensure the query is really empty // the query string is empty when a logical plan is deserialized - if (query_metrics.query.empty() && !root) { + if (query_metrics.query_name.empty() && !root) { return; } @@ -707,7 +664,7 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { constexpr idx_t TOTAL_BOX_WIDTH = 50; ss << "┌────────────────────────────────────────────────┐\n"; ss << "│┌──────────────────────────────────────────────┐│\n"; - string total_time = "Total Time: " + RenderTiming(query_metrics.latency.Elapsed()); + string total_time = "Total Time: " + RenderTiming(query_metrics.GetMetricInSeconds(MetricType::LATENCY)); ss << "││" + DrawPadded(total_time, TOTAL_BOX_WIDTH - 4) + "││\n"; ss << "│└──────────────────────────────────────────────┘│\n"; ss << "└────────────────────────────────────────────────┘\n"; @@ -721,18 +678,24 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { } } -InsertionOrderPreservingMap QueryProfiler::JSONSanitize(const InsertionOrderPreservingMap &input) { +Value QueryProfiler::JSONSanitize(const Value &input) { + D_ASSERT(input.type().id() == LogicalTypeId::MAP); + InsertionOrderPreservingMap result; - for (auto &it : input) { - auto key = it.first; + auto children = MapValue::GetChildren(input); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + if (StringUtil::StartsWith(key, "__")) { key = StringUtil::Replace(key, "__", ""); key = StringUtil::Replace(key, "_", " "); key = StringUtil::Title(key); } - result[key] = it.second; + result[key] = value; } - return result; + return Value::MAP(result); } string QueryProfiler::JSONSanitize(const std::string &text) { @@ -772,7 +735,12 @@ string QueryProfiler::JSONSanitize(const std::string &text) { static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) { auto result_obj = yyjson_mut_obj(doc); auto &profiling_info = node.GetProfilingInfo(); - profiling_info.extra_info = QueryProfiler::JSONSanitize(profiling_info.extra_info); + + if (profiling_info.Enabled(profiling_info.settings, MetricType::EXTRA_INFO)) { + profiling_info.metrics[MetricType::EXTRA_INFO] = + QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricType::EXTRA_INFO)); + } + profiling_info.WriteMetricsToJSON(doc, result_obj); auto children_list = yyjson_mut_arr(doc); @@ -784,44 +752,56 @@ static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) return result_obj; } -static string StringifyAndFree(yyjson_mut_doc *doc, yyjson_mut_val *object) { - auto data = 
yyjson_mut_val_write_opts(object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, - nullptr, nullptr); - if (!data) { - yyjson_mut_doc_free(doc); +static string StringifyAndFree(ConvertedJSONHolder &json_holder, yyjson_mut_val *object) { + json_holder.stringified_json = yyjson_mut_val_write_opts( + object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, nullptr, nullptr); + if (!json_holder.stringified_json) { throw InternalException("The plan could not be rendered as JSON, yyjson failed"); } - auto result = string(data); - free(data); - yyjson_mut_doc_free(doc); + auto result = string(json_holder.stringified_json); return result; } +void QueryProfiler::ToLog() const { + lock_guard guard(lock); + + if (!root) { + // No root, not much to do + return; + } + + auto &settings = root->GetProfilingInfo(); + + settings.WriteMetricsToLog(context); +} + string QueryProfiler::ToJSON() const { lock_guard guard(lock); - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; - if (query_metrics.query.empty() && !root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "empty"); - return StringifyAndFree(doc, result_obj); + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + yyjson_mut_doc_set_root(json_holder.doc, result_obj); + + if (query_metrics.query_name.empty() && !root) { + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "empty"); + return StringifyAndFree(json_holder, result_obj); } if (!root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "error"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "error"); + return StringifyAndFree(json_holder, result_obj); } auto &settings = root->GetProfilingInfo(); - settings.WriteMetricsToJSON(doc, result_obj); + settings.WriteMetricsToJSON(json_holder.doc, result_obj); // recursively print the physical operator tree - auto children_list = yyjson_mut_arr(doc); - yyjson_mut_obj_add_val(doc, result_obj, "children", children_list); - auto child = ToJSONRecursive(doc, *root->GetChild(0)); + auto children_list = yyjson_mut_arr(json_holder.doc); + yyjson_mut_obj_add_val(json_holder.doc, result_obj, "children", children_list); + auto child = ToJSONRecursive(json_holder.doc, *root->GetChild(0)); yyjson_mut_arr_add_val(children_list, child); - return StringifyAndFree(doc, result_obj); + return StringifyAndFree(json_holder, result_obj); } void QueryProfiler::WriteToFile(const char *path, string &info) const { @@ -839,7 +819,7 @@ profiler_settings_t EraseQueryRootSettings(profiler_settings_t settings) { for (auto &setting : settings) { if (MetricsUtils::IsOptimizerMetric(setting) || MetricsUtils::IsPhaseTimingMetric(setting) || - MetricsUtils::IsQueryGlobalMetric(setting)) { + MetricsUtils::IsRootScopeMetric(setting)) { phase_timing_settings_to_erase.insert(setting); } } @@ -867,11 +847,11 @@ unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root node->depth = depth; if (depth != 0) { - info.metrics[MetricsType::OPERATOR_NAME] = root_p.GetName(); - info.MetricSum(MetricsType::OPERATOR_TYPE, static_cast(root_p.type)); + info.metrics[MetricType::OPERATOR_NAME] = root_p.GetName(); + info.MetricSum(MetricType::OPERATOR_TYPE, static_cast(root_p.type)); } - if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = root_p.ParamsToString(); + if (info.Enabled(info.settings, 
MetricType::EXTRA_INFO)) { + info.metrics[MetricType::EXTRA_INFO] = Value::MAP(root_p.ParamsToString()); } tree_map.insert(make_pair(reference(root_p), reference(*node))); @@ -904,13 +884,20 @@ string QueryProfiler::RenderDisabledMessage(ProfilerPrintFormat format) const { node_0_0 [label="Query profiling is disabled. Use 'PRAGMA enable_profiling;' to enable profiling!"]; } )"; + case ProfilerPrintFormat::MERMAID: + return R"(flowchart TD + node_0_0["`**DISABLED** +Query profiling is disabled. +Use 'PRAGMA enable_profiling;' to enable profiling!`"] +)"; case ProfilerPrintFormat::JSON: { - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + yyjson_mut_doc_set_root(json_holder.doc, result_obj); - yyjson_mut_obj_add_str(doc, result_obj, "result", "disabled"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "disabled"); + return StringifyAndFree(json_holder, result_obj); } default: throw InternalException("Unknown ProfilerPrintFormat \"%s\"", EnumUtil::ToString(format)); @@ -962,7 +949,4 @@ void QueryProfiler::MoveOptimizerPhasesToRoot() { } } -void QueryProfiler::Propagate(QueryProfiler &) { -} - } // namespace duckdb diff --git a/src/duckdb/src/main/query_result.cpp b/src/duckdb/src/main/query_result.cpp index a20f9a87a..1ee9c47df 100644 --- a/src/duckdb/src/main/query_result.cpp +++ b/src/duckdb/src/main/query_result.cpp @@ -41,7 +41,7 @@ const ExceptionType &BaseQueryResult::GetErrorType() const { return error.Type(); } -const std::string &BaseQueryResult::GetError() { +const std::string &BaseQueryResult::GetError() const { D_ASSERT(HasError()); return error.Message(); } @@ -110,6 +110,10 @@ unique_ptr QueryResult::Fetch() { return chunk; } +unique_ptr QueryResult::FetchRaw() { + return FetchInternal(); +} + bool QueryResult::Equals(QueryResult &other) { // LCOV_EXCL_START // first compare the success state of the results if (success != other.success) { diff --git a/src/duckdb/src/main/relation.cpp b/src/duckdb/src/main/relation.cpp index 9a28349e7..8efa2cd11 100644 --- a/src/duckdb/src/main/relation.cpp +++ b/src/duckdb/src/main/relation.cpp @@ -241,7 +241,12 @@ BoundStatement Relation::Bind(Binder &binder) { } shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) { - return make_shared_ptr(shared_from_this(), schema_name, table_name); + return InsertRel(INVALID_CATALOG, schema_name, table_name); +} + +shared_ptr Relation::InsertRel(const string &catalog_name, const string &schema_name, + const string &table_name) { + return make_shared_ptr(shared_from_this(), catalog_name, schema_name, table_name); } void Relation::Insert(const string &table_name) { @@ -249,7 +254,11 @@ void Relation::Insert(const string &table_name) { } void Relation::Insert(const string &schema_name, const string &table_name) { - auto insert = InsertRel(schema_name, table_name); + Insert(INVALID_CATALOG, schema_name, table_name); +} + +void Relation::Insert(const string &catalog_name, const string &schema_name, const string &table_name) { + auto insert = InsertRel(catalog_name, schema_name, table_name); auto res = insert->Execute(); if (res->HasError()) { const string prepended_message = "Failed to insert into table '" + table_name + "': "; @@ -258,30 +267,40 @@ void Relation::Insert(const string &schema_name, const string 
&table_name) { } void Relation::Insert(const vector> &values) { - vector column_names; - auto rel = make_shared_ptr(context->GetContext(), values, std::move(column_names), "values"); - rel->Insert(GetAlias()); + throw InvalidInputException("INSERT with values can only be used on base tables!"); } void Relation::Insert(vector>> &&expressions) { - vector column_names; - auto rel = make_shared_ptr(context->GetContext(), std::move(expressions), std::move(column_names), - "values"); - rel->Insert(GetAlias()); + (void)std::move(expressions); + throw InvalidInputException("INSERT with expressions can only be used on base tables!"); } shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name, bool temporary, OnCreateConflict on_conflict) { - return make_shared_ptr(shared_from_this(), schema_name, table_name, temporary, on_conflict); + return CreateRel(INVALID_CATALOG, schema_name, table_name, temporary, on_conflict); +} + +shared_ptr Relation::CreateRel(const string &catalog_name, const string &schema_name, + const string &table_name, bool temporary, OnCreateConflict on_conflict) { + return make_shared_ptr(shared_from_this(), catalog_name, schema_name, table_name, temporary, + on_conflict); } void Relation::Create(const string &table_name, bool temporary, OnCreateConflict on_conflict) { - Create(INVALID_SCHEMA, table_name, temporary, on_conflict); + Create(INVALID_CATALOG, INVALID_SCHEMA, table_name, temporary, on_conflict); } void Relation::Create(const string &schema_name, const string &table_name, bool temporary, OnCreateConflict on_conflict) { - auto create = CreateRel(schema_name, table_name, temporary, on_conflict); + Create(INVALID_CATALOG, schema_name, table_name, temporary, on_conflict); +} + +void Relation::Create(const string &catalog_name, const string &schema_name, const string &table_name, bool temporary, + OnCreateConflict on_conflict) { + if (table_name.empty()) { + throw ParserException("Empty table name not supported"); + } + auto create = CreateRel(catalog_name, schema_name, table_name, temporary, on_conflict); auto res = create->Execute(); if (res->HasError()) { const string prepended_message = "Failed to create table '" + table_name + "': "; @@ -337,7 +356,9 @@ unique_ptr Relation::Query(const string &sql) const { } unique_ptr Relation::Query(const string &name, const string &sql) { - CreateView(name); + bool replace = true; + bool temp = IsReadOnly(); + CreateView(name, replace, temp); return Query(sql); } @@ -394,8 +415,8 @@ string Relation::ToString() { } // LCOV_EXCL_START -unique_ptr Relation::GetQueryNode() { - throw InternalException("Cannot create a query node from this node type"); +string Relation::GetQuery() { + return GetQueryNode()->ToString(); } void Relation::Head(idx_t limit) { diff --git a/src/duckdb/src/main/relation/create_table_relation.cpp b/src/duckdb/src/main/relation/create_table_relation.cpp index 2492f244b..2a08194c0 100644 --- a/src/duckdb/src/main/relation/create_table_relation.cpp +++ b/src/duckdb/src/main/relation/create_table_relation.cpp @@ -14,12 +14,21 @@ CreateTableRelation::CreateTableRelation(shared_ptr child_p, string sc TryBindRelation(columns); } +CreateTableRelation::CreateTableRelation(shared_ptr child_p, string catalog_name, string schema_name, + string table_name, bool temporary_p, OnCreateConflict on_conflict) + : Relation(child_p->context, RelationType::CREATE_TABLE_RELATION), child(std::move(child_p)), + catalog_name(std::move(catalog_name)), schema_name(std::move(schema_name)), 
table_name(std::move(table_name)), + temporary(temporary_p), on_conflict(on_conflict) { + TryBindRelation(columns); +} + BoundStatement CreateTableRelation::Bind(Binder &binder) { auto select = make_uniq(); select->node = child->GetQueryNode(); CreateStatement stmt; auto info = make_uniq(); + info->catalog = catalog_name; info->schema = schema_name; info->table = table_name; info->query = std::move(select); @@ -29,6 +38,14 @@ BoundStatement CreateTableRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateTableRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a create table relation"); +} + +string CreateTableRelation::GetQuery() { + return string(); +} + const vector &CreateTableRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/create_view_relation.cpp b/src/duckdb/src/main/relation/create_view_relation.cpp index c00deef38..6f77f013f 100644 --- a/src/duckdb/src/main/relation/create_view_relation.cpp +++ b/src/duckdb/src/main/relation/create_view_relation.cpp @@ -35,6 +35,14 @@ BoundStatement CreateViewRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateViewRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string CreateViewRelation::GetQuery() { + return string(); +} + const vector &CreateViewRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/delete_relation.cpp b/src/duckdb/src/main/relation/delete_relation.cpp index 64b3f231e..2ec60f664 100644 --- a/src/duckdb/src/main/relation/delete_relation.cpp +++ b/src/duckdb/src/main/relation/delete_relation.cpp @@ -26,6 +26,14 @@ BoundStatement DeleteRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr DeleteRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a delete relation"); +} + +string DeleteRelation::GetQuery() { + return string(); +} + const vector &DeleteRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/explain_relation.cpp b/src/duckdb/src/main/relation/explain_relation.cpp index f91e1d29f..9f2976c9d 100644 --- a/src/duckdb/src/main/relation/explain_relation.cpp +++ b/src/duckdb/src/main/relation/explain_relation.cpp @@ -20,6 +20,14 @@ BoundStatement ExplainRelation::Bind(Binder &binder) { return binder.Bind(explain.Cast()); } +unique_ptr ExplainRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an explain relation"); +} + +string ExplainRelation::GetQuery() { + return string(); +} + const vector &ExplainRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/insert_relation.cpp b/src/duckdb/src/main/relation/insert_relation.cpp index 9728570a0..461133255 100644 --- a/src/duckdb/src/main/relation/insert_relation.cpp +++ b/src/duckdb/src/main/relation/insert_relation.cpp @@ -13,17 +13,32 @@ InsertRelation::InsertRelation(shared_ptr child_p, string schema_name, TryBindRelation(columns); } +InsertRelation::InsertRelation(shared_ptr child_p, string catalog_name, string schema_name, string table_name) + : Relation(child_p->context, RelationType::INSERT_RELATION), child(std::move(child_p)), + catalog_name(std::move(catalog_name)), schema_name(std::move(schema_name)), table_name(std::move(table_name)) { + TryBindRelation(columns); +} + BoundStatement InsertRelation::Bind(Binder &binder) { InsertStatement stmt; auto select = make_uniq(); select->node = 
child->GetQueryNode(); + stmt.catalog = catalog_name; stmt.schema = schema_name; stmt.table = table_name; stmt.select_statement = std::move(select); return binder.Bind(stmt.Cast()); } +unique_ptr InsertRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an insert relation"); +} + +string InsertRelation::GetQuery() { + return string(); +} + const vector &InsertRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/projection_relation.cpp b/src/duckdb/src/main/relation/projection_relation.cpp index 0577ce73d..50345abf3 100644 --- a/src/duckdb/src/main/relation/projection_relation.cpp +++ b/src/duckdb/src/main/relation/projection_relation.cpp @@ -2,6 +2,7 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp index e0cf2e280..79aa1f981 100644 --- a/src/duckdb/src/main/relation/query_relation.cpp +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -49,6 +49,10 @@ unique_ptr QueryRelation::GetQueryNode() { return std::move(select->node); } +string QueryRelation::GetQuery() { + return query; +} + unique_ptr QueryRelation::GetTableRef() { auto subquery_ref = make_uniq(GetSelectStatement(), GetAlias()); return std::move(subquery_ref); @@ -61,9 +65,6 @@ BoundStatement QueryRelation::Bind(Binder &binder) { auto result = Relation::Bind(binder); auto &replacements = binder.GetReplacementScans(); if (first_bind) { - auto &query_node = *select_stmt->node; - auto &cte_map = query_node.cte_map; - vector> materialized_ctes; for (auto &kv : replacements) { auto &name = kv.first; auto &tableref = kv.second; @@ -83,29 +84,16 @@ BoundStatement QueryRelation::Bind(Binder &binder) { auto cte_info = make_uniq(); cte_info->query = std::move(select); + auto subquery = make_uniq(std::move(select_stmt), "query_relation"); + auto top_level_select = make_uniq(); + auto top_level_select_node = make_uniq(); + top_level_select_node->select_list.push_back(make_uniq()); + top_level_select_node->from_table = std::move(subquery); + auto &cte_map = top_level_select_node->cte_map; + top_level_select->node = std::move(top_level_select_node); cte_map.map[name] = std::move(cte_info); - - // We can not rely on CTE inlining anymore, so we need to add a materialized CTE node - // to the query node to ensure that the CTE exists - auto &cte_entry = cte_map.map[name]; - auto mat_cte = make_uniq(); - mat_cte->ctename = name; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt->node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = root->cte_map.Copy(); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); + select_stmt = std::move(top_level_select); } - select_stmt->node = std::move(root); } replacements.clear(); binder.SetBindingMode(saved_binding_mode); diff --git a/src/duckdb/src/main/relation/read_json_relation.cpp b/src/duckdb/src/main/relation/read_json_relation.cpp index 2f849597d..6fe7e4a7c 100644 --- a/src/duckdb/src/main/relation/read_json_relation.cpp +++ 
b/src/duckdb/src/main/relation/read_json_relation.cpp @@ -15,7 +15,6 @@ ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, vec : TableFunctionRelation(context, auto_detect ? "read_json_auto" : "read_json", {MultiFileReader::CreateValueFromFileList(input)}, std::move(options)), alias(std::move(alias_p)) { - InitializeAlias(input); } @@ -24,7 +23,6 @@ ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, str : TableFunctionRelation(context, auto_detect ? "read_json_auto" : "read_json", {Value(json_file_p)}, std::move(options)), json_file(std::move(json_file_p)), alias(std::move(alias_p)) { - if (alias.empty()) { alias = StringUtil::Split(json_file, ".")[0]; } diff --git a/src/duckdb/src/main/relation/table_relation.cpp b/src/duckdb/src/main/relation/table_relation.cpp index c82ace698..78d5aaaa4 100644 --- a/src/duckdb/src/main/relation/table_relation.cpp +++ b/src/duckdb/src/main/relation/table_relation.cpp @@ -3,6 +3,7 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" #include "duckdb/main/relation/delete_relation.hpp" +#include "duckdb/main/relation/value_relation.hpp" #include "duckdb/main/relation/update_relation.hpp" #include "duckdb/parser/parser.hpp" #include "duckdb/main/client_context.hpp" @@ -87,4 +88,17 @@ void TableRelation::Delete(const string &condition) { del->Execute(); } +void TableRelation::Insert(const vector> &values) { + vector column_names; + auto rel = make_shared_ptr(context->GetContext(), values, std::move(column_names), "values"); + rel->Insert(description->database, description->schema, description->table); +} + +void TableRelation::Insert(vector>> &&expressions) { + vector column_names; + auto rel = make_shared_ptr(context->GetContext(), std::move(expressions), std::move(column_names), + "values"); + rel->Insert(description->database, description->schema, description->table); +} + } // namespace duckdb diff --git a/src/duckdb/src/main/relation/update_relation.cpp b/src/duckdb/src/main/relation/update_relation.cpp index 9176cf2f2..81d85ca89 100644 --- a/src/duckdb/src/main/relation/update_relation.cpp +++ b/src/duckdb/src/main/relation/update_relation.cpp @@ -35,6 +35,14 @@ BoundStatement UpdateRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr UpdateRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string UpdateRelation::GetQuery() { + return string(); +} + const vector &UpdateRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_csv_relation.cpp b/src/duckdb/src/main/relation/write_csv_relation.cpp index 4795c7a51..f77d6f1ee 100644 --- a/src/duckdb/src/main/relation/write_csv_relation.cpp +++ b/src/duckdb/src/main/relation/write_csv_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteCSVRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteCSVRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write CSV relation"); +} + +string WriteCSVRelation::GetQuery() { + return string(); +} + const vector &WriteCSVRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_parquet_relation.cpp b/src/duckdb/src/main/relation/write_parquet_relation.cpp index d6e403618..b1dfdb29f 100644 --- a/src/duckdb/src/main/relation/write_parquet_relation.cpp +++ b/src/duckdb/src/main/relation/write_parquet_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteParquetRelation::Bind(Binder 
&binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteParquetRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write parquet relation"); +} + +string WriteParquetRelation::GetQuery() { + return string(); +} + const vector &WriteParquetRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/result_set_manager.cpp b/src/duckdb/src/main/result_set_manager.cpp new file mode 100644 index 000000000..d8913b8e8 --- /dev/null +++ b/src/duckdb/src/main/result_set_manager.cpp @@ -0,0 +1,51 @@ +#include "duckdb/main/result_set_manager.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +ManagedResultSet::ManagedResultSet() : valid(false) { +} + +ManagedResultSet::ManagedResultSet(const weak_ptr &db_p, vector> &handles_p) + : valid(true), db(db_p), handles(handles_p) { +} + +bool ManagedResultSet::IsValid() const { + return valid; +} + +shared_ptr ManagedResultSet::GetDatabase() const { + D_ASSERT(IsValid()); + return db.lock(); +} + +vector> &ManagedResultSet::GetHandles() { + D_ASSERT(IsValid()); + return *handles; +} + +ResultSetManager::ResultSetManager(DatabaseInstance &db_p) : db(db_p.shared_from_this()) { +} + +ResultSetManager &ResultSetManager::Get(ClientContext &context) { + return Get(*context.db); +} + +ResultSetManager &ResultSetManager::Get(DatabaseInstance &db_p) { + return db_p.GetResultSetManager(); +} + +ManagedResultSet ResultSetManager::Add(ColumnDataAllocator &allocator) { + lock_guard guard(lock); + auto &handles = *open_results.emplace(allocator, make_uniq>>()).first->second; + return ManagedResultSet(db, handles); +} + +void ResultSetManager::Remove(ColumnDataAllocator &allocator) { + lock_guard guard(lock); + open_results.erase(allocator); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/secret/secret_manager.cpp b/src/duckdb/src/main/secret/secret_manager.cpp index 8788b595b..7b132baf8 100644 --- a/src/duckdb/src/main/secret/secret_manager.cpp +++ b/src/duckdb/src/main/secret/secret_manager.cpp @@ -49,7 +49,10 @@ void SecretManager::Initialize(DatabaseInstance &db) { for (auto &path_ele : path_components) { config.default_secret_path = fs.JoinPath(config.default_secret_path, path_ele); } - config.secret_path = config.default_secret_path; + // Use default path if none has been specified by the user configuration + if (config.secret_path.empty()) { + config.secret_path = config.default_secret_path; + } // Set the defaults for persistent storage config.default_persistent_storage = LOCAL_FILE_STORAGE_NAME; diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 96c3065f2..0bacd69bc 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -78,6 +78,28 @@ Value AllowCommunityExtensionsSetting::GetSetting(const ClientContext &context) return Value::BOOLEAN(config.options.allow_community_extensions); } +//===----------------------------------------------------------------------===// +// Allow Parser Override Extension +//===----------------------------------------------------------------------===// +void AllowParserOverrideExtensionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (!OnGlobalSet(db, config, input)) { + return; + } + config.options.allow_parser_override_extension = input.GetValue(); +} + +void 
AllowParserOverrideExtensionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (!OnGlobalReset(db, config)) { + return; + } + config.options.allow_parser_override_extension = DBConfigOptions().allow_parser_override_extension; +} + +Value AllowParserOverrideExtensionSetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.allow_parser_override_extension); +} + //===----------------------------------------------------------------------===// // Allow Unredacted Secrets //===----------------------------------------------------------------------===// @@ -232,6 +254,13 @@ Value DebugForceExternalSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.force_external); } +//===----------------------------------------------------------------------===// +// Debug Physical Table Scan Execution Strategy +//===----------------------------------------------------------------------===// +void DebugPhysicalTableScanExecutionStrategySetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + //===----------------------------------------------------------------------===// // Debug Verify Vector //===----------------------------------------------------------------------===// @@ -535,6 +564,22 @@ void StorageBlockPrefetchSetting::OnSet(SettingCallbackInfo &info, Value ¶me EnumUtil::FromString(StringValue::Get(parameter)); } +//===----------------------------------------------------------------------===// +// Variant Minimum Shredding Size +//===----------------------------------------------------------------------===// +void VariantMinimumShreddingSize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.variant_minimum_shredding_size = input.GetValue(); +} + +void VariantMinimumShreddingSize::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.variant_minimum_shredding_size = DBConfigOptions().variant_minimum_shredding_size; +} + +Value VariantMinimumShreddingSize::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BIGINT(config.options.variant_minimum_shredding_size); +} + //===----------------------------------------------------------------------===// // Zstd Min String Length //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/main/settings/custom_settings.cpp b/src/duckdb/src/main/settings/custom_settings.cpp index 8e9b491e3..cd7b1b97b 100644 --- a/src/duckdb/src/main/settings/custom_settings.cpp +++ b/src/duckdb/src/main/settings/custom_settings.cpp @@ -14,6 +14,7 @@ #include "duckdb/common/enums/access_mode.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/operator/double_cast_operator.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" @@ -31,9 +32,20 @@ #include "duckdb/storage/storage_manager.hpp" #include "duckdb/logging/logger.hpp" #include "duckdb/logging/log_manager.hpp" +#include "duckdb/common/type_visitor.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" +#include "duckdb/storage/block_allocator.hpp" namespace duckdb { +constexpr const char *LoggingMode::Name; +constexpr const char *LoggingLevel::Name; +constexpr const char *EnableLogging::Name; +constexpr const char *LoggingStorage::Name; +constexpr 
const char *EnabledLogTypes::Name; +constexpr const char *DisabledLogTypes::Name; +constexpr const char *DisabledFilesystemsSetting::Name; + const string GetDefaultUserAgent() { return StringUtil::Format("duckdb/%s(%s)", DuckDB::LibraryVersion(), DuckDB::Platform()); } @@ -150,6 +162,27 @@ bool AllowCommunityExtensionsSetting::OnGlobalReset(DatabaseInstance *db, DBConf return true; } +//===----------------------------------------------------------------------===// +// Allow Parser Override +//===----------------------------------------------------------------------===// +bool AllowParserOverrideExtensionSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = input.GetValue(); + vector supported_options = {"default", "fallback", "strict", "strict_when_supported"}; + string supported_option_string; + for (const auto &option : supported_options) { + if (StringUtil::CIEquals(new_value, option)) { + return true; + } + } + throw InvalidInputException("Unrecognized value for parser override setting. Valid options are: %s", + StringUtil::Join(supported_options, ", ")); +} + +bool AllowParserOverrideExtensionSetting::OnGlobalReset(DatabaseInstance *db, DBConfig &config) { + config.options.allow_parser_override_extension = "default"; + return true; +} + //===----------------------------------------------------------------------===// // Allow Persistent Secrets //===----------------------------------------------------------------------===// @@ -285,6 +318,41 @@ Value AllowedPathsSetting::GetSetting(const ClientContext &context) { return Value::LIST(LogicalType::VARCHAR, std::move(allowed_paths)); } +//===----------------------------------------------------------------------===// +// Block Allocator Memory +//===----------------------------------------------------------------------===// +void BlockAllocatorMemorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + const auto input_string = input.ToString(); + idx_t size; + if (!input_string.empty() && input_string.back() == '%') { + double percentage; + if (!TryDoubleCast(input_string.c_str(), input_string.size() - 1, percentage, false) || percentage < 0 || + percentage > 100) { + throw InvalidInputException("Unable to parse valid percentage (input: %s)", input_string); + } + size = LossyNumericCast(percentage) * config.options.maximum_memory / 100; + } else { + size = DBConfig::ParseMemoryLimit(input_string); + } + if (db) { + BlockAllocator::Get(*db).Resize(size); + } + config.options.block_allocator_size = size; +} + +void BlockAllocatorMemorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + const auto size = DBConfigOptions().block_allocator_size; + if (db) { + BlockAllocator::Get(*db).Resize(size); + } + config.options.block_allocator_size = size; +} + +Value BlockAllocatorMemorySetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return StringUtil::BytesToHumanReadableString(config.options.block_allocator_size); +} + //===----------------------------------------------------------------------===// // Checkpoint Threshold //===----------------------------------------------------------------------===// @@ -301,7 +369,7 @@ Value CheckpointThresholdSetting::GetSetting(const ClientContext &context) { //===----------------------------------------------------------------------===// // Custom Profiling Settings //===----------------------------------------------------------------------===// -bool 
IsEnabledOptimizer(MetricsType metric, const set &disabled_optimizers) { +bool IsEnabledOptimizer(MetricType metric, const set &disabled_optimizers) { auto matching_optimizer_type = MetricsUtils::GetOptimizerTypeByMetric(metric); if (matching_optimizer_type != OptimizerType::INVALID && disabled_optimizers.find(matching_optimizer_type) == disabled_optimizers.end()) { @@ -310,22 +378,39 @@ bool IsEnabledOptimizer(MetricsType metric, const set &disabled_o return false; } -static profiler_settings_t FillTreeNodeSettings(unordered_map &json, +static profiler_settings_t FillTreeNodeSettings(unordered_map &input, const set &disabled_optimizers) { profiler_settings_t metrics; string invalid_settings; - for (auto &entry : json) { - MetricsType setting; + for (auto &entry : input) { + MetricType setting; + MetricGroup group = MetricGroup::INVALID; try { - setting = EnumUtil::FromString(StringUtil::Upper(entry.first)); + setting = EnumUtil::FromString(StringUtil::Upper(entry.first)); } catch (std::exception &ex) { - if (!invalid_settings.empty()) { - invalid_settings += ", "; + try { + group = EnumUtil::FromString(StringUtil::Upper(entry.first)); + } catch (std::exception &ex) { + if (!invalid_settings.empty()) { + invalid_settings += ", "; + } + invalid_settings += entry.first; + continue; + } + } + if (group != MetricGroup::INVALID) { + if (entry.second == "true") { + auto group_metrics = MetricsUtils::GetMetricsByGroupType(group); + for (auto &metric : group_metrics) { + if (!MetricsUtils::IsOptimizerMetric(metric) || IsEnabledOptimizer(metric, disabled_optimizers)) { + metrics.insert(metric); + } + } } - invalid_settings += entry.first; continue; } + if (StringUtil::Lower(entry.second) == "true" && (!MetricsUtils::IsOptimizerMetric(setting) || IsEnabledOptimizer(setting, disabled_optimizers))) { metrics.insert(setting); @@ -339,7 +424,7 @@ static profiler_settings_t FillTreeNodeSettings(unordered_map &j } void AddOptimizerMetrics(profiler_settings_t &settings, const set &disabled_optimizers) { - if (settings.find(MetricsType::ALL_OPTIMIZERS) != settings.end()) { + if (settings.find(MetricType::ALL_OPTIMIZERS) != settings.end()) { auto optimizer_metrics = MetricsUtils::GetOptimizerMetrics(); for (auto &metric : optimizer_metrics) { if (IsEnabledOptimizer(metric, disabled_optimizers)) { @@ -353,9 +438,9 @@ void CustomProfilingSettingsSetting::SetLocal(ClientContext &context, const Valu auto &config = ClientConfig::GetConfig(context); // parse the file content - unordered_map json; + unordered_map input_json; try { - json = StringUtil::ParseJSONMap(input.ToString())->Flatten(); + input_json = StringUtil::ParseJSONMap(input.ToString())->Flatten(); } catch (std::exception &ex) { throw IOException("Could not parse the custom profiler settings file due to incorrect JSON: \"%s\". Make sure " "all the keys and values start with a quote. 
", @@ -366,7 +451,7 @@ void CustomProfilingSettingsSetting::SetLocal(ClientContext &context, const Valu auto &db_config = DBConfig::GetConfig(context); auto &disabled_optimizers = db_config.options.disabled_optimizers; - auto settings = FillTreeNodeSettings(json, disabled_optimizers); + auto settings = FillTreeNodeSettings(input_json, disabled_optimizers); AddOptimizerMetrics(settings, disabled_optimizers); config.profiler_settings = settings; } @@ -374,7 +459,7 @@ void CustomProfilingSettingsSetting::SetLocal(ClientContext &context, const Valu void CustomProfilingSettingsSetting::ResetLocal(ClientContext &context) { auto &config = ClientConfig::GetConfig(context); config.enable_profiler = ClientConfig().enable_profiler; - config.profiler_settings = ProfilingInfo::DefaultSettings(); + config.profiler_settings = MetricsUtils::GetDefaultMetrics(); } Value CustomProfilingSettingsSetting::GetSetting(const ClientContext &context) { @@ -627,6 +712,8 @@ bool EnableExternalAccessSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &co for (auto &path : attached_paths) { config.AddAllowedPath(path); config.AddAllowedPath(path + ".wal"); + config.AddAllowedPath(path + ".checkpoint.wal"); + config.AddAllowedPath(path + ".recovery.wal"); } } if (config.options.use_temporary_directory && !config.options.temporary_directory.empty()) { @@ -681,6 +768,110 @@ void EnableLogging::ResetGlobal(DatabaseInstance *db_p, DBConfig &config) { db.GetLogManager().SetEnableLogging(false); } +//===----------------------------------------------------------------------===// +// Force VARIANT Shredding +//===----------------------------------------------------------------------===// + +void ForceVariantShredding::SetGlobal(DatabaseInstance *_, DBConfig &config, const Value &value) { + auto &force_variant_shredding = config.options.force_variant_shredding; + + if (value.type().id() != LogicalTypeId::VARCHAR) { + throw InvalidInputException("The argument to 'force_variant_shredding' should be of type VARCHAR, not %s", + value.type().ToString()); + } + + auto logical_type = TransformStringToLogicalType(value.GetValue()); + TypeVisitor::Contains(logical_type, [](const LogicalType &type) { + if (type.IsNested()) { + if (type.id() != LogicalTypeId::STRUCT && type.id() != LogicalTypeId::LIST) { + throw InvalidInputException("Shredding can consist of the nested types LIST (for ARRAY Variant values) " + "or STRUCT (for OBJECT Variant values), not %s", + type.ToString()); + } + if (type.id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(type)) { + throw InvalidInputException("STRUCT types in the shredding can not be empty"); + } + return false; + } + switch (type.id()) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_TZ: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: + case LogicalTypeId::UUID: + case LogicalTypeId::BIGNUM: + case LogicalTypeId::TIME_NS: + case 
LogicalTypeId::INTERVAL: + case LogicalTypeId::BIT: + case LogicalTypeId::GEOMETRY: + break; + default: + throw InvalidInputException("Variants can not be shredded on type: %s", type.ToString()); + } + return false; + }); + + auto shredding_type = TypeVisitor::VisitReplace(logical_type, [](const LogicalType &type) { + return LogicalType::STRUCT({{"untyped_value_index", LogicalType::UINTEGER}, {"typed_value", type}}); + }); + force_variant_shredding = + LogicalType::STRUCT({{"unshredded", VariantShredding::GetUnshreddedType()}, {"shredded", shredding_type}}); +} + +void ForceVariantShredding::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.force_variant_shredding = LogicalType::INVALID; +} + +Value ForceVariantShredding::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.force_variant_shredding.ToString()); +} + +//===----------------------------------------------------------------------===// +// Extension Directory +//===----------------------------------------------------------------------===// +void ExtensionDirectoriesSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + config.options.extension_directories.clear(); + + auto &list = ListValue::GetChildren(input); + for (auto &val : list) { + config.options.extension_directories.emplace_back(val.GetValue()); + } +} + +void ExtensionDirectoriesSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.extension_directories = DBConfigOptions().extension_directories; +} + +Value ExtensionDirectoriesSetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + vector extension_directories; + for (auto &dir : config.options.extension_directories) { + extension_directories.emplace_back(dir); + } + return Value::LIST(LogicalType::VARCHAR, std::move(extension_directories)); +} + //===----------------------------------------------------------------------===// // Logging Mode //===----------------------------------------------------------------------===// @@ -792,6 +983,17 @@ void EnableProfilingSetting::SetLocal(ClientContext &context, const Value &input config.enable_profiler = true; config.emit_profiler_output = true; + if (parameter != "no_output" && !config.profiler_save_location.empty()) { + auto &file_system = FileSystem::GetFileSystem(context); + const auto file_type = file_system.ExtractExtension(config.profiler_save_location); + if (file_type != parameter && file_type != "txt") { + throw ParserException( + "Profiler file type (%s) must either have the same file extension as the profiling output type (%s), " + "or be a '.txt' file. 
Set 'profiling_output' to a '%s' file or run \"RESET profiling_output\" first.", + config.profiler_save_location, parameter, parameter); + } + } + if (parameter == "json") { config.profiler_print_format = ProfilerPrintFormat::JSON; } else if (parameter == "query_tree") { @@ -818,9 +1020,9 @@ void EnableProfilingSetting::SetLocal(ClientContext &context, const Value &input } else if (parameter == "graphviz") { config.profiler_print_format = ProfilerPrintFormat::GRAPHVIZ; } else { - throw ParserException( - "Unrecognized print format %s, supported formats: [json, query_tree, query_tree_optimizer, no_output]", - parameter); + throw ParserException("Unrecognized print format %s, supported formats: [json, query_tree, " + "query_tree_optimizer, no_output, html, graphviz]", + parameter); } } @@ -962,9 +1164,15 @@ void ForceCompressionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, } else { auto compression_type = CompressionTypeFromString(compression); //! FIXME: do we want to try to retrieve the AttachedDatabase here to get the StorageManager ?? - if (CompressionTypeIsDeprecated(compression_type)) { - throw ParserException("Attempted to force a deprecated compression type (%s)", - CompressionTypeToString(compression_type)); + auto compression_availability_result = CompressionTypeIsAvailable(compression_type); + if (!compression_availability_result.IsAvailable()) { + if (compression_availability_result.IsDeprecated()) { + throw ParserException("Attempted to force a deprecated compression type (%s)", + CompressionTypeToString(compression_type)); + } else { + throw ParserException("Attempted to force a compression type that isn't available yet (%s)", + CompressionTypeToString(compression_type)); + } } if (compression_type == CompressionType::COMPRESSION_AUTO) { auto compression_types = StringUtil::Join(ListCompressionTypes(), ", "); @@ -1206,6 +1414,26 @@ void PerfectHtThresholdSetting::OnSet(SettingCallbackInfo &info, Value &input) { void ProfileOutputSetting::SetLocal(ClientContext &context, const Value &input) { auto &config = ClientConfig::GetConfig(context); auto parameter = input.ToString(); + + if (!parameter.empty() && config.profiler_print_format != ProfilerPrintFormat::NO_OUTPUT) { + auto &file_system = FileSystem::GetFileSystem(context); + const auto file_type = file_system.ExtractExtension(parameter); + if (file_type != "txt") { + try { + EnumUtil::FromString(file_type); + } catch (std::exception &e) { + throw ParserException("Invalid output file type: %s", file_type); + } + } + + const auto printer_format = StringUtil::Lower(EnumUtil::ToString(config.profiler_print_format)); + if (file_type != printer_format && file_type != "txt") { + throw ParserException("Profiler file type (%s) must either have the same file extension as the profiling " + "output type (%s), or be a '.txt' file. 
Set \"enable_profiling = \'%s\'\" first.", + parameter, printer_format, file_type); + } + } + config.profiler_save_location = parameter; } @@ -1242,6 +1470,12 @@ void ProfilingModeSetting::SetLocal(ClientContext &context, const Value &input) for (auto &setting : phase_timing_settings) { config.profiler_settings.insert(setting); } + } else if (parameter == "all") { + config.enable_profiler = true; + auto all_metrics = MetricsUtils::GetAllMetrics(); + for (auto &metric : all_metrics) { + config.profiler_settings.insert(metric); + } } else { throw ParserException("Unrecognized profiling mode \"%s\", supported formats: [standard, detailed]", parameter); } diff --git a/src/duckdb/src/main/stream_query_result.cpp b/src/duckdb/src/main/stream_query_result.cpp index 9e4b06caa..676791e2f 100644 --- a/src/duckdb/src/main/stream_query_result.cpp +++ b/src/duckdb/src/main/stream_query_result.cpp @@ -69,7 +69,7 @@ static bool ExecutionErrorOccurred(StreamExecutionResult result) { return false; } -unique_ptr StreamQueryResult::FetchInternal(ClientContextLock &lock) { +unique_ptr StreamQueryResult::FetchNextInternal(ClientContextLock &lock) { bool invalidate_query = true; unique_ptr chunk; try { @@ -106,12 +106,12 @@ unique_ptr StreamQueryResult::FetchInternal(ClientContextLock &lock) return nullptr; } -unique_ptr StreamQueryResult::FetchRaw() { +unique_ptr StreamQueryResult::FetchInternal() { unique_ptr chunk; { auto lock = LockContext(); CheckExecutableInternal(*lock); - chunk = FetchInternal(*lock); + chunk = FetchNextInternal(*lock); } if (!chunk || chunk->ColumnCount() == 0 || chunk->size() == 0) { Close(); diff --git a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp index 6e00355e0..7098caae4 100644 --- a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp +++ b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp @@ -8,6 +8,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_order.hpp" #include "duckdb/planner/operator/logical_projection.hpp" @@ -142,9 +143,31 @@ void ColumnLifetimeAnalyzer::VisitOperator(LogicalOperator &op) { return; } case LogicalOperatorType::LOGICAL_DISTINCT: { - // distinct, all projected columns are used for the DISTINCT computation - // mark all columns as used and continue to the children - // FIXME: DISTINCT with expression list does not implicitly reference everything + // DISTINCT ON only references the expressions specified in the target list (and optional ORDER BY), + auto &distinct = op.Cast(); + if (distinct.distinct_type == DistinctType::DISTINCT_ON) { + auto add_bindings = [&](Expression &expr) { + vector bindings; + ExtractColumnBindings(expr, bindings); + for (auto &binding : bindings) { + column_references.insert(binding); + } + }; + for (auto &target : distinct.distinct_targets) { + if (target) { + add_bindings(*target); + } + } + if (distinct.order_by) { + for (auto &order : distinct.order_by->orders) { + if (order.expression) { + add_bindings(*order.expression); + } + } + } + break; + } + // DISTINCT without targets references the entire projection list everything_referenced = true; break; } diff --git a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp 
new file mode 100644 index 000000000..5004acca0 --- /dev/null +++ b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp @@ -0,0 +1,604 @@ +#include "duckdb/optimizer/common_subplan_optimizer.hpp" + +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/optimizer/cte_inlining.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Subplan Signature/Info +//===--------------------------------------------------------------------===// +enum class ConversionType { + TO_CANONICAL, + RESTORE_ORIGINAL, +}; + +class PlanSignatureCreateState { +public: + PlanSignatureCreateState() : stream(DEFAULT_BLOCK_ALLOC_SIZE), serializer(stream) { + } + +public: + void Initialize(LogicalOperator &op) { + to_canonical.clear(); + from_canonical.clear(); + table_indices.clear(); + expression_info.clear(); + + for (const auto &child_op : op.children) { + for (const auto &child_cb : child_op->GetColumnBindings()) { + const auto &original = child_cb.table_index; + auto it = to_canonical.find(original); + if (it != to_canonical.end()) { + continue; // We've seen this table index before + } + const auto canonical = CANONICAL_TABLE_INDEX_OFFSET + to_canonical.size(); + to_canonical[original] = canonical; + from_canonical[canonical] = original; + } + } + } + + template + bool Convert(LogicalOperator &op) { + switch (TYPE) { + case ConversionType::TO_CANONICAL: + D_ASSERT(children.empty()); + children = std::move(op.children); + break; + case ConversionType::RESTORE_ORIGINAL: + D_ASSERT(op.children.empty()); + op.children = std::move(children); + break; + } + ConvertTableIndices(op); + return ConvertExpressions(op); + } + +private: + template + void ConvertTableIndices(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_CTE_REF: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_PROJECTION: + ConvertTableIndex(op.Cast().table_index, 0); + break; + case LogicalOperatorType::LOGICAL_PIVOT: + ConvertTableIndex(op.Cast().pivot_index, 0); + break; + case LogicalOperatorType::LOGICAL_UNNEST: + ConvertTableIndex(op.Cast().unnest_index, 0); + break; + case LogicalOperatorType::LOGICAL_WINDOW: + ConvertTableIndex(op.Cast().window_index, 0); + break; + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggr = op.Cast(); + ConvertTableIndex(aggr.group_index, 0); + ConvertTableIndex(aggr.aggregate_index, 1); + ConvertTableIndex(aggr.groupings_index, 2); + break; + } + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + ConvertTableIndex(op.Cast().table_index, 0); + break; + default: + break; + } + } + + template + void ConvertTableIndex(idx_t &table_index, const idx_t i) { + switch (TYPE) { + case ConversionType::TO_CANONICAL: + D_ASSERT(table_indices.size() == i); + 
table_indices.emplace_back(table_index); + table_index = CANONICAL_TABLE_INDEX_OFFSET + i; + break; + case ConversionType::RESTORE_ORIGINAL: + table_index = table_indices[i]; + break; + } + } + + template + bool ConvertExpressions(LogicalOperator &op) { + const auto &table_index_mapping = TYPE == ConversionType::TO_CANONICAL ? to_canonical : from_canonical; + bool can_materialize = true; + idx_t info_idx = 0; + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *expr) { + ExpressionIterator::EnumerateExpression(*expr, [&](unique_ptr &child) { + if (child->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { + auto &col_ref = child->Cast(); + auto &table_index = col_ref.binding.table_index; + auto it = table_index_mapping.find(table_index); + D_ASSERT(it != table_index_mapping.end()); + table_index = it->second; + } + switch (TYPE) { + case ConversionType::TO_CANONICAL: + expression_info.emplace_back(std::move(child->alias), child->query_location); + child->alias.clear(); + child->query_location.SetInvalid(); + break; + case ConversionType::RESTORE_ORIGINAL: + auto &info = expression_info[info_idx++]; + child->alias = std::move(info.first); + child->query_location = info.second; + break; + } + if (child->IsVolatile()) { + can_materialize = false; + } + }); + }); + return can_materialize; + } + +private: + static constexpr idx_t CANONICAL_TABLE_INDEX_OFFSET = 10000000000000; + +public: + MemoryStream stream; + BinarySerializer serializer; + + //! Mapping from original table index to canonical table index (and reverse mapping) + unordered_map to_canonical; + unordered_map from_canonical; + + //! Place to temporarily store children + vector> children; + + //! Utility vectors to temporarily store table indices and expression info + vector table_indices; + vector> expression_info; +}; + +class PlanSignature { +private: + PlanSignature(const MemoryStream &stream_p, idx_t offset_p, idx_t length_p, + vector> &&child_signatures_p, idx_t operator_count_p, + idx_t base_table_count_p, idx_t max_base_table_cardinality_p) + : stream(stream_p), offset(offset_p), length(length_p), + signature_hash(Hash(stream_p.GetData() + offset, length)), child_signatures(std::move(child_signatures_p)), + operator_count(operator_count_p), base_table_count(base_table_count_p), + max_base_table_cardinality(max_base_table_cardinality_p) { + } + +public: + static unique_ptr Create(PlanSignatureCreateState &state, LogicalOperator &op, + vector> &&child_signatures) { + if (!OperatorIsSupported(op)) { + return nullptr; + } + state.Initialize(op); + + auto can_materialize = state.Convert(op); + + // Serialize canonical representation of operator + const auto offset = state.stream.GetPosition(); + state.serializer.Begin(); + try { // Operators will throw if they cannot serialize, so we need to try/catch here + op.Serialize(state.serializer); + } catch (std::exception &) { + can_materialize = false; + } + state.serializer.End(); + const auto length = state.stream.GetPosition() - offset; + + // Convert back from canonical + state.Convert(op); + + if (can_materialize) { + idx_t operator_count = 1; + idx_t base_table_count = 0; + idx_t max_base_table_cardinality = 0; + if (op.children.empty()) { + base_table_count++; + if (op.has_estimated_cardinality) { + max_base_table_cardinality = op.estimated_cardinality; + } + } + for (auto &child_signature : child_signatures) { + operator_count += child_signature.get().OperatorCount(); + base_table_count += child_signature.get().BaseTableCount(); + 
max_base_table_cardinality = + MaxValue(max_base_table_cardinality, child_signature.get().MaxBaseTableCardinality()); + } + return unique_ptr(new PlanSignature(state.stream, offset, length, + std::move(child_signatures), operator_count, + base_table_count, max_base_table_cardinality)); + } + return nullptr; + } + + idx_t OperatorCount() const { + return operator_count; + } + + idx_t BaseTableCount() const { + return base_table_count; + } + + idx_t MaxBaseTableCardinality() const { + return max_base_table_cardinality; + } + + hash_t HashSignature() const { + auto res = signature_hash; + for (auto &child : child_signatures) { + res = CombineHash(res, child.get().HashSignature()); + } + return res; + } + + bool Equals(const PlanSignature &other) const { + if (this->GetSignature() != other.GetSignature()) { + return false; + } + if (this->child_signatures.size() != other.child_signatures.size()) { + return false; + } + for (idx_t child_idx = 0; child_idx < this->child_signatures.size(); ++child_idx) { + if (!this->child_signatures[child_idx].get().Equals(other.child_signatures[child_idx].get())) { + return false; + } + } + return true; + } + +private: + String GetSignature() const { + return String(char_ptr_cast(stream.GetData() + offset), NumericCast(length)); + } + + static bool OperatorIsSupported(const LogicalOperator &op) { + if (!op.SupportSerialization()) { + return false; + } + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_FILTER: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_WINDOW: + case LogicalOperatorType::LOGICAL_UNNEST: + case LogicalOperatorType::LOGICAL_LIMIT: + case LogicalOperatorType::LOGICAL_ORDER_BY: + case LogicalOperatorType::LOGICAL_TOP_N: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_PIVOT: + case LogicalOperatorType::LOGICAL_GET: + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + return true; + case LogicalOperatorType::LOGICAL_CHUNK_GET: + // Avoid serializing massive amounts of data (this is here because of the "Test TPCH arrow roundtrip" test) + return op.Cast().collection->Count() < 1000; + default: + // Unsupported: + // - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + // - case LogicalOperatorType::LOGICAL_SAMPLE: + // - case LogicalOperatorType::LOGICAL_COPY_DATABASE: + // - case LogicalOperatorType::LOGICAL_DELIM_GET: + // - case LogicalOperatorType::LOGICAL_CTE_REF: + // - case LogicalOperatorType::LOGICAL_JOIN: + // - case LogicalOperatorType::LOGICAL_DELIM_JOIN: + // - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + // - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + // - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + // - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR + return false; + } + } + +private: + const MemoryStream &stream; + const idx_t offset; + const idx_t length; + + const hash_t signature_hash; + + const vector> child_signatures; + const idx_t operator_count; + const idx_t base_table_count; + const 
idx_t max_base_table_cardinality; +}; + +struct PlanSignatureHash { + std::size_t operator()(const PlanSignature &k) const { + return k.HashSignature(); + } +}; + +struct PlanSignatureEquality { + bool operator()(const PlanSignature &a, const PlanSignature &b) const { + return a.Equals(b); + } +}; + +struct SubplanInfo { + explicit SubplanInfo(unique_ptr &op) : subplans({op}), lowest_common_ancestor(op) { + } + vector>> subplans; + reference> lowest_common_ancestor; +}; + +using subplan_map_t = unordered_map, SubplanInfo, PlanSignatureHash, PlanSignatureEquality>; + +//===--------------------------------------------------------------------===// +// CommonSubplanFinder +//===--------------------------------------------------------------------===// +class CommonSubplanFinder { +public: + CommonSubplanFinder() { + } + +private: + struct OperatorInfo { + OperatorInfo(unique_ptr &parent_p, const idx_t &depth_p) : parent(parent_p), depth(depth_p) { + } + + unique_ptr &parent; + const idx_t depth; + unique_ptr signature; + }; + + struct StackNode { + explicit StackNode(unique_ptr &op_p) : op(op_p), child_index(0) { + } + + bool HasMoreChildren() const { + return child_index < op->children.size(); + } + + unique_ptr &GetNextChild() { + D_ASSERT(child_index < op->children.size()); + return op->children[child_index++]; + }; + + unique_ptr &op; + idx_t child_index; + }; + +public: + subplan_map_t FindCommonSubplans(reference> root) { + // Find first operator with more than 1 child + while (root.get()->children.size() == 1) { + root = root.get()->children[0]; + } + + // Recurse through query plan using stack-based recursion + vector stack; + stack.emplace_back(root); + operator_infos.emplace(root, OperatorInfo(root, 0)); + + while (!stack.empty()) { + auto &current = stack.back(); + + // Depth-first + if (current.HasMoreChildren()) { + auto &child = current.GetNextChild(); + operator_infos.emplace(child, OperatorInfo(current.op, stack.size())); + stack.emplace_back(child); + continue; + } + + if (!RefersToSameObject(current.op, root.get())) { + // We have all child information for this operator now, compute signature + auto &signature = operator_infos.find(current.op)->second.signature; + signature = CreatePlanSignature(current.op); + + // Add to subplans (if we actually got a signature) + if (signature) { + auto it = subplans.find(*signature); + if (it == subplans.end()) { + subplans.emplace(*signature, SubplanInfo(current.op)); + } else { + auto &info = it->second; + info.subplans.emplace_back(current.op); + info.lowest_common_ancestor = LowestCommonAncestor(info.lowest_common_ancestor, current.op); + } + } + } + + // Done with current + stack.pop_back(); + } + + // Filter out redundant or ineligible subplans before returning + for (auto it = subplans.begin(); it != subplans.end();) { + if (it->first.get().OperatorCount() == 1) { + it = subplans.erase(it); // Just one operator in this subplan + continue; + } + if (it->second.subplans.size() == 1) { + it = subplans.erase(it); // No other identical subplan + continue; + } + auto &subplan = it->second.subplans[0].get(); + auto &parent = operator_infos.find(subplan)->second.parent; + auto &parent_signature = operator_infos.find(parent)->second.signature; + if (parent_signature) { + auto parent_it = subplans.find(*parent_signature); + if (parent_it != subplans.end() && it->second.subplans.size() == parent_it->second.subplans.size()) { + it = subplans.erase(it); // Parent has exact same number of identical subplans + continue; + } + } + if 
(CTEInlining::EndsInAggregateOrDistinct(*subplan) || IsSelectiveMultiTablePlan(subplan)) { + it++; // This subplan might be useful + } else { + it = subplans.erase(it); // Not eligible for materialization + } + } + + return std::move(subplans); + } + +private: + unique_ptr CreatePlanSignature(const unique_ptr &op) { + vector> child_signatures; + for (auto &child : op->children) { + auto it = operator_infos.find(child); + D_ASSERT(it != operator_infos.end()); + if (!it->second.signature) { + return nullptr; // Failed to create signature from one of the children + } + child_signatures.emplace_back(*it->second.signature); + } + return PlanSignature::Create(state, *op, std::move(child_signatures)); + } + + unique_ptr &LowestCommonAncestor(reference> a, + reference> b) { + auto a_it = operator_infos.find(a); + auto b_it = operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + + // Get parents of a and b until they're at the same depth + while (a_it->second.depth > b_it->second.depth) { + a = a_it->second.parent; + a_it = operator_infos.find(a); + D_ASSERT(a_it != operator_infos.end()); + } + while (b_it->second.depth > a_it->second.depth) { + b = b_it->second.parent; + b_it = operator_infos.find(b); + D_ASSERT(b_it != operator_infos.end()); + } + + // Move up one level at a time for both until ancestor is the same + while (!RefersToSameObject(a, b)) { + a_it = operator_infos.find(a); + b_it = operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + a = a_it->second.parent; + b = b_it->second.parent; + } + + return a.get(); + } + + bool IsSelectiveMultiTablePlan(unique_ptr &op) const { + static constexpr idx_t CARDINALITY_RATIO = 2; + + if (!op->has_estimated_cardinality) { + return false; + } + const auto &signature = *operator_infos.find(op)->second.signature; + if (signature.BaseTableCount() <= 1) { + return false; + } + return op->estimated_cardinality < signature.MaxBaseTableCardinality() / CARDINALITY_RATIO; + } + +private: + //! Mapping from operator to info + reference_map_t, OperatorInfo> operator_infos; + //! Mapping from subplan signature to subplan information + subplan_map_t subplans; + //! 
State for creating PlanSignature with reusable data structures + PlanSignatureCreateState state; +}; + +//===--------------------------------------------------------------------===// +// CommonSubplanOptimizer +//===--------------------------------------------------------------------===// +CommonSubplanOptimizer::CommonSubplanOptimizer(Optimizer &optimizer_p) : optimizer(optimizer_p) { +} + +static void ConvertSubplansToCTE(Optimizer &optimizer, unique_ptr &op, SubplanInfo &subplan_info) { + const auto cte_index = optimizer.binder.GenerateTableIndex(); + const auto cte_name = StringUtil::Format("__common_subplan_1"); + + // Resolve types to be used for creating the materialized CTE and refs + op->ResolveOperatorTypes(); + + // Get types and names + const auto &types = subplan_info.subplans[0].get()->types; + vector col_names; + for (idx_t i = 0; i < types.size(); i++) { + col_names.emplace_back(StringUtil::Format("%s_col_%llu", cte_name, i)); + } + + // Create CTE refs and figure out column binding replacements + vector> cte_refs; + ColumnBindingReplacer replacer; + for (auto &subplan : subplan_info.subplans) { + cte_refs.emplace_back( + make_uniq(optimizer.binder.GenerateTableIndex(), cte_index, types, col_names)); + const auto old_bindings = subplan.get()->GetColumnBindings(); + const auto new_bindings = cte_refs.back()->GetColumnBindings(); + D_ASSERT(old_bindings.size() == new_bindings.size()); + for (idx_t i = 0; i < old_bindings.size(); i++) { + replacer.replacement_bindings.emplace_back(old_bindings[i], new_bindings[i]); + } + } + + // Create the materialized CTE and replace the common subplans with references to it + auto &lowest_common_ancestor = subplan_info.lowest_common_ancestor.get(); + auto cte = + make_uniq(cte_name, cte_index, types.size(), std::move(subplan_info.subplans[0].get()), + std::move(lowest_common_ancestor), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + for (idx_t i = 0; i < subplan_info.subplans.size(); i++) { + subplan_info.subplans[i].get() = std::move(cte_refs[i]); + } + lowest_common_ancestor = std::move(cte); + + // Replace bindings of subplans with those of the CTE refs + replacer.stop_operator = lowest_common_ancestor.get(); + replacer.VisitOperator(*op); // Replace from the root until CTE + replacer.VisitOperator(*lowest_common_ancestor->children[1]); // Replace in CTE child +} + +unique_ptr CommonSubplanOptimizer::Optimize(unique_ptr op) { + // Bottom-up identification of identical subplans + CommonSubplanFinder finder; + auto subplans = finder.FindCommonSubplans(op); + + // Identify the single best subplan (TODO: for now, in the future we should identify multiple) + if (subplans.empty()) { + return op; // No eligible subplans + } + auto best_it = subplans.begin(); + for (auto it = ++subplans.begin(); it != subplans.end(); it++) { + if (it->first.get().OperatorCount() > best_it->first.get().OperatorCount()) { + best_it = it; + } + } + + // Create a CTE! 
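+ // The call below roughly rewrites the plan as follows, given identical subplans S below a
+ // lowest common ancestor A:
+ //   A(..., S, ..., S, ...)  ->  MATERIALIZED_CTE(__common_subplan_1 := S, A(..., CTE_REF, ..., CTE_REF, ...))
+ // remapping the column bindings of the replaced subplans to the bindings of the new CTE refs.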
+ ConvertSubplansToCTE(optimizer, op, best_it->second); + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/cte_inlining.cpp b/src/duckdb/src/optimizer/cte_inlining.cpp index 0b9e942ee..116d64768 100644 --- a/src/duckdb/src/optimizer/cte_inlining.cpp +++ b/src/duckdb/src/optimizer/cte_inlining.cpp @@ -55,10 +55,14 @@ static bool ContainsLimit(const LogicalOperator &op) { return false; } -static bool EndsInAggregateOrDistinct(const LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY || - op.type == LogicalOperatorType::LOGICAL_DISTINCT) { +bool CTEInlining::EndsInAggregateOrDistinct(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_WINDOW: return true; + default: + break; } if (op.children.size() != 1) { return false; @@ -146,8 +150,7 @@ void CTEInlining::TryInlining(unique_ptr &op) { } } -bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, - bool requires_copy) { +bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, bool requires_copy) { if (op->type == LogicalOperatorType::LOGICAL_CTE_REF) { auto &cteref = op->Cast(); auto &cte = materialized_cte.Cast(); diff --git a/src/duckdb/src/optimizer/deliminator.cpp b/src/duckdb/src/optimizer/deliminator.cpp index 0d24635f6..210d17a87 100644 --- a/src/duckdb/src/optimizer/deliminator.cpp +++ b/src/duckdb/src/optimizer/deliminator.cpp @@ -214,7 +214,10 @@ bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, cons auto &other_colref = other_side.Cast(); replacement_bindings.emplace_back(delim_colref.binding, other_colref.binding); - if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + // Only add IS NOT NULL filter for regular equality/inequality comparisons + // Do NOT add for DISTINCT FROM variants, as they handle NULL correctly + if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM && + cond.comparison != ExpressionType::COMPARE_DISTINCT_FROM) { auto is_not_null_expr = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); is_not_null_expr->children.push_back(other_side.Copy()); @@ -334,6 +337,7 @@ bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_ auto &colref = delim_side.Cast(); if (colref.binding == traced_binding) { auto join_comparison = join_condition.comparison; + auto original_join_comparison = join_condition.comparison; // Save original for later check if (delim_condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || delim_condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { // We need to compare NULL values @@ -352,7 +356,10 @@ bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_ // join condition was a not equal and filtered out all NULLS. // DELIM JOIN need to do that for not DELIM_GET side. Easiest way is to change the // comparison expression type. 
See duckdb/duckdb#16803 - if (delim_join.join_type != JoinType::MARK) { + // Only convert if the ORIGINAL join had != or = (not DISTINCT FROM variants) + if (delim_join.join_type != JoinType::MARK && + original_join_comparison != ExpressionType::COMPARE_DISTINCT_FROM && + original_join_comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { if (delim_condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM) { delim_condition.comparison = ExpressionType::COMPARE_NOTEQUAL; } diff --git a/src/duckdb/src/optimizer/empty_result_pullup.cpp b/src/duckdb/src/optimizer/empty_result_pullup.cpp index 74128bc90..daf3c3f37 100644 --- a/src/duckdb/src/optimizer/empty_result_pullup.cpp +++ b/src/duckdb/src/optimizer/empty_result_pullup.cpp @@ -40,9 +40,18 @@ unique_ptr EmptyResultPullup::PullUpEmptyJoinChildren(unique_pt } break; } - // TODO: For ANTI joins, if the right child is empty, you can replace the whole join with - // the left child - case JoinType::ANTI: + // For ANTI joins, if the right child is empty, the whole join collapses to the left child + case JoinType::ANTI: { + if (op->children[1]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT && + op->type != LogicalOperatorType::LOGICAL_EXCEPT) { + op = std::move(op->children[0]); + break; + } + if (op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + op = make_uniq(std::move(op)); + } + break; + } case JoinType::MARK: case JoinType::SINGLE: case JoinType::LEFT: { diff --git a/src/duckdb/src/optimizer/expression_rewriter.cpp b/src/duckdb/src/optimizer/expression_rewriter.cpp index c8836a380..e21b5bcfd 100644 --- a/src/duckdb/src/optimizer/expression_rewriter.cpp +++ b/src/duckdb/src/optimizer/expression_rewriter.cpp @@ -55,7 +55,7 @@ unique_ptr ExpressionRewriter::ConstantOrNull(vector(value)); return make_uniq(type, func, std::move(children), ConstantOrNull::Bind(std::move(value))); } diff --git a/src/duckdb/src/optimizer/filter_combiner.cpp b/src/duckdb/src/optimizer/filter_combiner.cpp index 8e4a295b4..f7099c9a1 100644 --- a/src/duckdb/src/optimizer/filter_combiner.cpp +++ b/src/duckdb/src/optimizer/filter_combiner.cpp @@ -1,6 +1,8 @@ #include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/string_common.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/planner/expression/bound_between_expression.hpp" @@ -23,6 +25,7 @@ #include "duckdb/optimizer/column_lifetime_analyzer.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "utf8proc_wrapper.hpp" namespace duckdb { @@ -281,6 +284,35 @@ static bool SupportedFilterComparison(ExpressionType expression_type) { } } +bool FilterCombiner::FindNextLegalUTF8(string &prefix_string) { + // find the start of the last codepoint + idx_t last_codepoint_start; + for (last_codepoint_start = prefix_string.size(); last_codepoint_start > 0; last_codepoint_start--) { + if (IsCharacter(prefix_string[last_codepoint_start - 1])) { + break; + } + } + if (last_codepoint_start == 0) { + throw InvalidInputException("Invalid UTF8 found in string \"%s\"", prefix_string); + } + last_codepoint_start--; + int codepoint_size; + auto codepoint = Utf8Proc::UTF8ToCodepoint(prefix_string.c_str() + last_codepoint_start, codepoint_size) + 1; + if (codepoint >= 0xD800 && codepoint <= 0xDFFF) { + // next codepoint falls within surrogate range increment to next valid 
character + codepoint = 0xE000; + } + char next_codepoint_text[4]; + int next_codepoint_size; + if (!Utf8Proc::CodepointToUtf8(codepoint, next_codepoint_size, next_codepoint_text)) { + // invalid codepoint + return false; + } + auto s = static_cast(next_codepoint_size); + prefix_string = prefix_string.substr(0, last_codepoint_start) + string(next_codepoint_text, s); + return true; +} + bool TypeSupportsConstantFilter(const LogicalType &type) { if (TypeIsNumeric(type.InternalType())) { return true; @@ -396,11 +428,14 @@ FilterPushdownResult FilterCombiner::TryPushdownPrefixFilter(TableFilterSet &tab auto &column_index = column_ids[column_ref.binding.column_index]; //! Replace prefix with a set of comparisons auto lower_bound = make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix_string)); - prefix_string[prefix_string.size() - 1]++; - auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(prefix_string)); table_filters.PushFilter(column_index, std::move(lower_bound)); - table_filters.PushFilter(column_index, std::move(upper_bound)); - return FilterPushdownResult::PUSHED_DOWN_FULLY; + if (FilterCombiner::FindNextLegalUTF8(prefix_string)) { + auto upper_bound = make_uniq(ExpressionType::COMPARE_LESSTHAN, Value(prefix_string)); + table_filters.PushFilter(column_index, std::move(upper_bound)); + return FilterPushdownResult::PUSHED_DOWN_FULLY; + } + // could not find next legal utf8 string - skip upper bound + return FilterPushdownResult::NO_PUSHDOWN; } FilterPushdownResult FilterCombiner::TryPushdownLikeFilter(TableFilterSet &table_filters, @@ -907,6 +942,12 @@ FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &com idx_t left_equivalence_set = GetEquivalenceSet(left_node); idx_t right_equivalence_set = GetEquivalenceSet(right_node); if (left_equivalence_set == right_equivalence_set) { + if (comparison.GetExpressionType() == ExpressionType::COMPARE_GREATERTHAN || + comparison.GetExpressionType() == ExpressionType::COMPARE_LESSTHAN) { + // non equal comparison has equal equivalence set, then it is unsatisfiable + // e.g., j > i AND i < j is unsatisfiable + return FilterResult::UNSATISFIABLE; + } // this equality filter already exists, prune it return FilterResult::SUCCESS; } diff --git a/src/duckdb/src/optimizer/filter_pullup.cpp b/src/duckdb/src/optimizer/filter_pullup.cpp index 219611387..f9ebb63c3 100644 --- a/src/duckdb/src/optimizer/filter_pullup.cpp +++ b/src/duckdb/src/optimizer/filter_pullup.cpp @@ -6,6 +6,7 @@ #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/planner/operator/logical_join.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" namespace duckdb { @@ -26,6 +27,7 @@ unique_ptr FilterPullup::Rewrite(unique_ptr op case LogicalOperatorType::LOGICAL_EXCEPT: return PullupSetOperation(std::move(op)); case LogicalOperatorType::LOGICAL_DISTINCT: + return PullupDistinct(std::move(op)); case LogicalOperatorType::LOGICAL_ORDER_BY: { // we can just pull directly through these operations without any rewriting op->children[0] = Rewrite(std::move(op->children[0])); @@ -115,6 +117,18 @@ unique_ptr FilterPullup::PullupCrossProduct(unique_ptr FilterPullup::PullupDistinct(unique_ptr op) { + const auto &distinct = op->Cast(); + if (distinct.distinct_type == DistinctType::DISTINCT) { + // Can pull up through a DISTINCT + op->children[0] = Rewrite(std::move(op->children[0])); + return op; + } + // Cannot pull up through a DISTINCT ON (see 
#19327) + D_ASSERT(distinct.distinct_type == DistinctType::DISTINCT_ON); + return FinishPullup(std::move(op)); +} + unique_ptr FilterPullup::GeneratePullupFilter(unique_ptr child, vector> &expressions) { unique_ptr filter = make_uniq(); diff --git a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp index c4f7bb04b..7c13386d9 100644 --- a/src/duckdb/src/optimizer/filter_pushdown.cpp +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -208,17 +208,23 @@ unique_ptr FilterPushdown::PushdownJoin(unique_ptrfilter)); D_ASSERT(result != FilterResult::UNSUPPORTED); - (void)result; + if (result == FilterResult::UNSATISFIABLE) { + // one of the filters is unsatisfiable - abort filter pushdown + return FilterResult::UNSATISFIABLE; + } } filters.clear(); + return FilterResult::SUCCESS; } FilterResult FilterPushdown::AddFilter(unique_ptr expr) { - PushFilters(); + if (PushFilters() == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } // split up the filters by AND predicate vector> expressions; expressions.push_back(std::move(expr)); @@ -276,51 +282,52 @@ unique_ptr FilterPushdown::PushFinalFilters(unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - for (idx_t i = 0; i < filters.size(); i++) { - auto &f = *filters[i]; - for (auto &child : op->children) { - FilterPushdown pushdown(optimizer, convert_mark_joins); +unique_ptr FilterPushdown::PushFiltersIntoDelimJoin(unique_ptr op) { + for (idx_t i = 0; i < filters.size(); i++) { + auto &f = *filters[i]; + for (auto &child : op->children) { + FilterPushdown pushdown(optimizer, convert_mark_joins); - // check if filter bindings can be applied to the child bindings. - auto child_bindings = child->GetColumnBindings(); - unordered_set child_bindings_table; - for (auto &binding : child_bindings) { - child_bindings_table.insert(binding.table_index); - } + // check if filter bindings can be applied to the child bindings. + auto child_bindings = child->GetColumnBindings(); + unordered_set child_bindings_table; + for (auto &binding : child_bindings) { + child_bindings_table.insert(binding.table_index); + } - // Check if ALL bindings of the filter are present in the child - bool should_push = true; - for (auto &binding : f.bindings) { - if (child_bindings_table.find(binding) == child_bindings_table.end()) { - should_push = false; - break; - } + // Check if ALL bindings of the filter are present in the child + bool should_push = true; + for (auto &binding : f.bindings) { + if (child_bindings_table.find(binding) == child_bindings_table.end()) { + should_push = false; + break; } + } - if (!should_push) { - continue; - } + if (!should_push) { + continue; + } - // copy the filter - auto filter_copy = f.filter->Copy(); - if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { - return make_uniq(std::move(op)); - } + // copy the filter + auto filter_copy = f.filter->Copy(); + if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { + return make_uniq(std::move(op)); + } - // push the filter into the child. - pushdown.GenerateFilters(); - child = pushdown.Rewrite(std::move(child)); + // push the filter into the child. 
+ pushdown.GenerateFilters(); + child = pushdown.Rewrite(std::move(child)); - // Don't push same filter again - filters.erase_at(i); - i--; - break; - } + // Don't push same filter again + filters.erase_at(i); + i--; + break; } } + return op; +} +unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { // unhandled type, first perform filter pushdown in its children for (auto &child : op->children) { FilterPushdown pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/join_elimination.cpp b/src/duckdb/src/optimizer/join_elimination.cpp new file mode 100644 index 000000000..61bb84fc5 --- /dev/null +++ b/src/duckdb/src/optimizer/join_elimination.cpp @@ -0,0 +1,317 @@ +#include "duckdb/optimizer/join_elimination.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/logical_operator_type.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_distinct.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include + +namespace duckdb { +void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptr parent, idx_t idx) { + if (op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + if (!parent) { + return; + } + D_ASSERT(!pipe_info.join_parent); + pipe_info.join_parent = parent; + pipe_info.join_index = idx; + left_child = CreateChildren(); + right_child = CreateChildren(); + left_child->OptimizeInternal(std::move(op.children[0])); + right_child->OptimizeInternal(std::move(op.children[1])); + return; + } + + VisitOperatorExpressions(op); + + switch (op.type) { + case LogicalOperatorType::LOGICAL_DISTINCT: { + auto &distinct = op.Cast(); + if (distinct.distinct_type != DistinctType::DISTINCT) { + break; + } + column_binding_set_t distinct_group; + if (distinct.distinct_targets[0]->type != ExpressionType::BOUND_COLUMN_REF) { + break; + } + idx_t table_idx = distinct.distinct_targets[0]->Cast().binding.table_index; + bool can_add = true; + for (auto &target : distinct.distinct_targets) { + if (target->type != ExpressionType::BOUND_COLUMN_REF) { + can_add = false; + break; + } + auto &col_ref = target->Cast(); + distinct_group.insert(col_ref.binding); + D_ASSERT(table_idx == col_ref.binding.table_index); + } + if (can_add) { + pipe_info.distinct_groups[table_idx] = std::move(distinct_group); + } + break; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggr = op.Cast(); + if (aggr.grouping_sets.size() > 1) { + break; + } + // only resolve group by columns for now + column_binding_set_t distinct_group; + idx_t table_idx = aggr.group_index; + for (idx_t i = 0; i < aggr.groups.size(); i++) { + distinct_group.insert(ColumnBinding(aggr.group_index, i)); + } + if (!distinct_group.empty()) { + pipe_info.distinct_groups[table_idx] = std::move(distinct_group); + } + break; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &projection = op.Cast(); + unordered_map> 
reference_records; + // for select distinct * from table, first projection then distinct. distinct_groups has record projection table + // id for select * from table group by col, first aggregate then projection. projection has aggregate table id. + + // before traverse children, first check whether any distinct group ref this projection + auto it = pipe_info.distinct_groups.find(projection.table_index); + if (it != pipe_info.distinct_groups.end()) { + column_binding_set_t new_distinct_group; + auto &expression = projection.expressions.get(it->second.begin()->column_index); + if (expression->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + // if the expression is not a column ref, we cannot eliminate the join + break; + } + bool could_add = true; + idx_t ref_id = expression->Cast().binding.table_index; + for (auto &col : it->second) { + auto &expression = projection.expressions.get(col.column_index); + if (expression->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + // if the expression is not a column ref, we cannot eliminate the join + could_add = false; + break; + } + auto &col_ref = expression->Cast(); + if (ref_id != col_ref.binding.table_index) { + could_add = false; + break; + } + new_distinct_group.insert(col_ref.binding); + } + if (could_add) { + pipe_info.distinct_groups[ref_id] = std::move(new_distinct_group); + } + } + break; + } + case LogicalOperatorType::LOGICAL_GET: { + auto &get = op.Cast(); + if (!get.table_filters.filters.empty()) { + pipe_info.has_filter = true; + } + break; + } + default: + break; + } + + if (op.children.size() == 1) { + OptimizeChildren(*op.children[0], op, idx); + } else { + children_root = op; + for (auto &child : op.children) { + auto child_optimizer = CreateChildren(); + child_optimizer->OptimizeInternal(std::move(child)); + children.emplace_back(std::move(child_optimizer)); + } + return; + } + + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &projection = op.Cast(); + // after traversed children, here check whether any distinct group added in children + unordered_map ref_table_columns; + for (idx_t idx = 0; idx < projection.expressions.size(); idx++) { + auto &expression = projection.expressions.get(idx); + if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + auto &col_ref = expression->Cast(); + auto distinct_group_it = pipe_info.distinct_groups.find(col_ref.binding.table_index); + if (distinct_group_it == pipe_info.distinct_groups.end()) { + continue; + } + if (ref_table_columns.find(col_ref.binding.table_index) == ref_table_columns.end()) { + auto ref = DistinctGroupRef(); + for (auto &col : distinct_group_it->second) { + ref.ref_column_ids.insert(col.column_index); + } + ref_table_columns[col_ref.binding.table_index] = ref; + } + ref_table_columns[col_ref.binding.table_index].distinct_group.insert( + ColumnBinding(projection.table_index, idx)); + ref_table_columns[col_ref.binding.table_index].ref_column_ids.erase(col_ref.binding.column_index); + } + } + for (auto &refs : ref_table_columns) { + if (refs.second.ref_column_ids.empty()) { + pipe_info.distinct_groups[projection.table_index] = std::move(refs.second.distinct_group); + } + } + break; + } + default: + D_ASSERT(op.type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN); + break; + } +} + +unique_ptr JoinElimination::Optimize(unique_ptr op) { + OptimizeInternal(std::move(op)); + if (pipe_info.join_parent || !children.empty()) { + pipe_info.root = TryEliminateJoin(); + } + return std::move(pipe_info.root); +} + 
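+// A minimal sketch of the kind of plan this pass removes (hypothetical tables t1/t2), assuming the
+// inner side is distinct on the join key and none of its columns are referenced above the join:
+//   SELECT t1.* FROM t1 LEFT JOIN (SELECT DISTINCT key FROM t2) d ON t1.key = d.key;
+// Because d is unique on key, the LEFT join can neither duplicate nor filter rows of t1, and since no
+// column of d survives the join, the join can be replaced by its outer child t1.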
+void JoinElimination::OptimizeInternal(unique_ptr op) { + pipe_info.root = std::move(op); + OptimizeChildren(*pipe_info.root, nullptr, 0); +} + +unique_ptr JoinElimination::TryEliminateJoin() { + D_ASSERT(pipe_info.root); + if (!children.empty()) { + D_ASSERT(!pipe_info.join_parent); + D_ASSERT(children_root); + D_ASSERT(children.size() == children_root->children.size()); + + for (idx_t idx = 0; idx < children.size(); idx++) { + children_root->children[idx] = children[idx]->TryEliminateJoin(); + } + return std::move(pipe_info.root); + } + if (!pipe_info.join_parent) { + return std::move(pipe_info.root); + } + + auto join_parent = pipe_info.join_parent; + + auto &join_op = pipe_info.join_parent->children[pipe_info.join_index]; + join_op->children[0] = left_child->TryEliminateJoin(); + join_op->children[1] = right_child->TryEliminateJoin(); + + auto &join = join_op->Cast(); + bool is_output_unique = false; + idx_t inner_idx = 1; + idx_t outer_idx = 0; + switch (join.join_type) { + case JoinType::LEFT: + break; + case JoinType::SINGLE: { + is_output_unique = true; + break; + case JoinType::RIGHT: + inner_idx = 0; + outer_idx = 1; + break; + } + default: + return std::move(pipe_info.root); + } + auto &inner_child = inner_idx == 0 ? left_child : right_child; + if (inner_child->pipe_info.has_filter) { + return std::move(pipe_info.root); + } + if (join.filter_pushdown) { + return std::move(pipe_info.root); + } + auto inner_bindings = join.children[inner_idx]->GetColumnBindings(); + // ensure join output columns only contains outer table columns + for (auto &binding : inner_bindings) { + if (pipe_info.ref_table_ids.find(binding.table_index) != pipe_info.ref_table_ids.end()) { + return std::move(pipe_info.root); + } + } + + if (inner_idx == 1) { + for (auto &distinct : right_child->pipe_info.distinct_groups) { + pipe_info.distinct_groups[distinct.first] = distinct.second; + } + } else { + for (auto &distinct : left_child->pipe_info.distinct_groups) { + pipe_info.distinct_groups[distinct.first] = distinct.second; + } + } + + if (pipe_info.distinct_groups.empty()) { + return std::move(pipe_info.root); + } + // 1. TODO: guarantee by primary/foreign key + + if (!is_output_unique) { + is_output_unique = true; + // 2. inner table join condition columns contains a whole distinct group + vector col_bindings; + for (auto &condition : join.conditions) { + if (condition.comparison != ExpressionType::COMPARE_EQUAL || + condition.left->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF || + condition.right->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + is_output_unique = false; + break; + } + auto inner_binding = inner_idx == 0 ? condition.left->Cast().binding + : condition.right->Cast().binding; + col_bindings.push_back(inner_binding); + } + if (is_output_unique && !ContainDistinctGroup(col_bindings)) { + is_output_unique = false; + } + } + if (!is_output_unique) { + // 3. 
join result columns in join condition contains a whole distinct group + auto outer_bindings = join.children[outer_idx]->GetColumnBindings(); + if (ContainDistinctGroup(outer_bindings)) { + is_output_unique = true; + } + } + + if (is_output_unique) { + join_parent->children[pipe_info.join_index] = std::move(join_op->children[outer_idx]); + } + return std::move(pipe_info.root); +} + +bool JoinElimination::ContainDistinctGroup(vector &column_bindings) { + D_ASSERT(!column_bindings.empty()); + auto &column_binding = column_bindings[0]; + auto it = pipe_info.distinct_groups.find(column_binding.table_index); + if (it == pipe_info.distinct_groups.end()) { + return false; + } + unordered_set used_column_ids; + for (auto &binding : column_bindings) { + if (it->second.find(binding) == it->second.end()) { + continue; + } + used_column_ids.emplace(binding.column_index); + } + return used_column_ids.size() == it->second.size(); +} + +unique_ptr JoinElimination::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { + pipe_info.ref_table_ids.insert(expr.binding.table_index); + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp index 02d6c2e45..74227ddda 100644 --- a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp @@ -146,6 +146,28 @@ void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, } } +bool JoinFilterPushdownOptimizer::IsFiltering(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_GET: { + auto &get = op->Cast(); + return !get.table_filters.filters.empty(); + } + case LogicalOperatorType::LOGICAL_FILTER: { + return true; + } + case LogicalOperatorType::LOGICAL_TOP_N: { + return true; + } + default: + for (const unique_ptr &child : op->children) { + if (IsFiltering(child)) { + return true; + } + } + return false; + } +} + void JoinFilterPushdownOptimizer::GenerateJoinFilters(LogicalComparisonJoin &join) { switch (join.join_type) { case JoinType::MARK: @@ -252,6 +274,14 @@ void JoinFilterPushdownOptimizer::GenerateJoinFilters(LogicalComparisonJoin &joi pushdown_info->min_max_aggregates.push_back(std::move(aggr_expr)); } } + if (!pushdown_info->probe_info.empty()) { + const auto &rhs_child = join.children[1]; + if (rhs_child->type == LogicalOperatorType::LOGICAL_DELIM_GET) { + pushdown_info->build_side_has_filter = IsFiltering(join.children[0]); + } else { + pushdown_info->build_side_has_filter = IsFiltering(join.children[1]); + } + } // set up the filter pushdown in the join itself join.filter_pushdown = std::move(pushdown_info); } diff --git a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp index 6dd086ffc..f09070ce3 100644 --- a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp +++ b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp @@ -22,9 +22,9 @@ bool CardinalityEstimator::EmptyFilter(FilterInfo &filter_info) { return false; } -void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { +void CardinalityEstimator::AddRelationStats(FilterInfo &filter_info) { D_ASSERT(filter_info.set.get().count >= 1); - for (const RelationsToTDom &r2tdom : relations_to_tdoms) { + for (const RelationsSetToStats &r2tdom : relation_set_stats) { auto &i_set = r2tdom.equivalent_relations; if (i_set.find(filter_info.left_binding) != 
i_set.end()) { // found an equivalent filter @@ -33,9 +33,9 @@ void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { } auto key = ColumnBinding(filter_info.left_binding.table_index, filter_info.left_binding.column_index); - RelationsToTDom new_r2tdom(column_binding_set_t({key})); + RelationsSetToStats new_r2tdom(column_binding_set_t({key})); - relations_to_tdoms.emplace_back(new_r2tdom); + relation_set_stats.emplace_back(new_r2tdom); } bool CardinalityEstimator::SingleColumnFilter(duckdb::FilterInfo &filter_info) { @@ -56,7 +56,7 @@ vector CardinalityEstimator::DetermineMatchingEquivalentSets(optional_ptr vector matching_equivalent_sets; idx_t equivalent_relation_index = 0; - for (const RelationsToTDom &r2tdom : relations_to_tdoms) { + for (const RelationsSetToStats &r2tdom : relation_set_stats) { auto &i_set = r2tdom.equivalent_relations; if (i_set.find(filter_info->left_binding) != i_set.end()) { matching_equivalent_sets.push_back(equivalent_relation_index); @@ -77,18 +77,18 @@ void CardinalityEstimator::AddToEquivalenceSets(optional_ptr filter_ // an equivalence relation is connecting two sets of equivalence relations // so push all relations from the second set into the first. Later we will delete // the second set. - for (ColumnBinding i : relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations) { - relations_to_tdoms.at(matching_equivalent_sets[0]).equivalent_relations.insert(i); + for (ColumnBinding i : relation_set_stats.at(matching_equivalent_sets[1]).equivalent_relations) { + relation_set_stats.at(matching_equivalent_sets[0]).equivalent_relations.insert(i); } - for (auto &column_name : relations_to_tdoms.at(matching_equivalent_sets[1]).column_names) { - relations_to_tdoms.at(matching_equivalent_sets[0]).column_names.push_back(column_name); + for (auto &column_name : relation_set_stats.at(matching_equivalent_sets[1]).column_names) { + relation_set_stats.at(matching_equivalent_sets[0]).column_names.push_back(column_name); } - relations_to_tdoms.at(matching_equivalent_sets[1]).equivalent_relations.clear(); - relations_to_tdoms.at(matching_equivalent_sets[1]).column_names.clear(); - relations_to_tdoms.at(matching_equivalent_sets[0]).filters.push_back(filter_info); + relation_set_stats.at(matching_equivalent_sets[1]).equivalent_relations.clear(); + relation_set_stats.at(matching_equivalent_sets[1]).column_names.clear(); + relation_set_stats.at(matching_equivalent_sets[0]).filters.push_back(filter_info); // add all values of one set to the other, delete the empty one } else if (matching_equivalent_sets.size() == 1) { - auto &tdom_i = relations_to_tdoms.at(matching_equivalent_sets.at(0)); + auto &tdom_i = relation_set_stats.at(matching_equivalent_sets.at(0)); tdom_i.equivalent_relations.insert(filter_info->left_binding); tdom_i.equivalent_relations.insert(filter_info->right_binding); tdom_i.filters.push_back(filter_info); @@ -96,8 +96,8 @@ void CardinalityEstimator::AddToEquivalenceSets(optional_ptr filter_ column_binding_set_t tmp; tmp.insert(filter_info->left_binding); tmp.insert(filter_info->right_binding); - relations_to_tdoms.emplace_back(tmp); - relations_to_tdoms.back().filters.push_back(filter_info); + relation_set_stats.emplace_back(tmp); + relation_set_stats.back().filters.push_back(filter_info); } } @@ -108,7 +108,7 @@ void CardinalityEstimator::InitEquivalentRelations(const vector GetEdges(vector &relations_to_tdom, +vector GetEdges(vector &relations_to_tdom, JoinRelationSet &requested_set) { vector res; for (auto &relation_2_tdom : 
relations_to_tdom) { @@ -213,6 +214,8 @@ JoinRelationSet &CardinalityEstimator::UpdateNumeratorRelations(Subgraph2Denomin } } +// Given two relations, here is where we consider the filter(s) that join them. +// This could use some work when it comes to join conditions that are not equality join conditions. double CardinalityEstimator::CalculateUpdatedDenom(Subgraph2Denominator left, Subgraph2Denominator right, FilterInfoWithTotalDomains &filter) { double new_denom = left.denom * right.denom; @@ -226,8 +229,8 @@ double CardinalityEstimator::CalculateUpdatedDenom(Subgraph2Denominator left, Su } }); if (comparison_type == ExpressionType::INVALID) { - new_denom *= - filter.has_tdom_hll ? static_cast(filter.tdom_hll) : static_cast(filter.tdom_no_hll); + new_denom *= filter.has_distinct_count_hll ? static_cast(filter.distinct_count_hll) + : static_cast(filter.distinct_count_no_hll); // no comparison is taking place, so the denominator is just the product of the left and right return new_denom; } @@ -238,8 +241,8 @@ double CardinalityEstimator::CalculateUpdatedDenom(Subgraph2Denominator left, Su case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOT_DISTINCT_FROM: // extra ratio stays 1 - extra_ratio = - filter.has_tdom_hll ? static_cast(filter.tdom_hll) : static_cast(filter.tdom_no_hll); + extra_ratio = filter.has_distinct_count_hll ? static_cast(filter.distinct_count_hll) + : static_cast(filter.distinct_count_no_hll); break; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_LESSTHAN: @@ -248,8 +251,8 @@ double CardinalityEstimator::CalculateUpdatedDenom(Subgraph2Denominator left, Su case ExpressionType::COMPARE_NOTEQUAL: case ExpressionType::COMPARE_DISTINCT_FROM: // Assume this blows up, but use the tdom to bound it a bit - extra_ratio = - filter.has_tdom_hll ? static_cast(filter.tdom_hll) : static_cast(filter.tdom_no_hll); + extra_ratio = filter.has_distinct_count_hll ? static_cast(filter.distinct_count_hll) + : static_cast(filter.distinct_count_no_hll); extra_ratio = pow(extra_ratio, 2.0 / 3.0); break; default: @@ -291,12 +294,12 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { // edges are guaranteed to be in order of largest tdom to smallest tdom. unordered_set unused_edge_tdoms; - auto edges = GetEdges(relations_to_tdoms, set); + auto edges = GetEdges(relation_set_stats, set); for (auto &edge : edges) { if (subgraphs.size() == 1 && subgraphs.at(0).relations->ToString() == set.ToString()) { // the first subgraph has connected all the desired relations, just skip the rest of the edges - if (edge.has_tdom_hll) { - unused_edge_tdoms.insert(edge.tdom_hll); + if (edge.has_distinct_count_hll) { + unused_edge_tdoms.insert(edge.distinct_count_hll); } continue; } @@ -392,6 +395,18 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { return DenomInfo(*subgraphs.at(0).numerator_relations, 1, subgraphs.at(0).denom * denom_multiplier); } +// Cardinality is calculated using logic found in +// https://blobs.duckdb.org/papers/tom-ebergen-msc-thesis-join-order-optimization-with-almost-no-statistics.pdf TL;DR +// Cardinality is estimated based on cardinality of base tables and the distinct counts of joined columns. If you have +// two tables A and B joined using A.x = B.y we assume that each tuple in A will match ~ B/(distinct(y)) tuples in B. +// The cardinality estimation then becomes (|A|x|B|) / max(distinct(x), distinct(y)). 
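A compact restatement of that estimate, added here as an illustrative aside rather than as part of the patch: for a single equality join on A.x = B.y,

\[ |A \bowtie_{x = y} B| \;\approx\; \frac{|A| \cdot |B|}{\max\big(d_A(x),\, d_B(y)\big)} \]

where d_A(x) and d_B(y) are the distinct counts of the join columns. For example, with |A| = 1,000, |B| = 100,000 and a maximum distinct count of 1,000, the estimate is 100,000 rows.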
+// If there are extra joins, you can add the cardinality of the table to the numerator, and the +// distinct count of the join condition to the denominator. +// One benefit of this cardinality estimation formula is that it is associative and commutative, which means regardless +// of the order of the joins/join tree, the cardinality estimate will always be the same. The drawback of this current +// implementation, however, is that it only considers equality join conditions. Some modifications have been made for +// comparison types like <, <=, >, >=, !=, but only a "penalty" was introduced, and the calculated cardinality is not +// based on stats (see CalculateUpdatedDenom()). template <> double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { if (relation_set_2_cardinality.find(new_set.ToString()) != relation_set_2_cardinality.end()) { @@ -400,6 +415,8 @@ double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) // can happen if a table has cardinality 0, or a tdom is set to 0 auto denom = GetDenominator(new_set); + // we pass numerator relations, because for semi and anti joins, we don't want to + // include cardinalities of relations on the RHS of a semi/anti join. auto numerator = GetNumerator(denom.numerator_relations); double result = numerator / denom.denominator; @@ -418,17 +435,17 @@ idx_t CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) return (idx_t)cardinality_as_double; } -bool SortTdoms(const RelationsToTDom &a, const RelationsToTDom &b) { - if (a.has_tdom_hll && b.has_tdom_hll) { - return a.tdom_hll > b.tdom_hll; +bool SortTdoms(const RelationsSetToStats &a, const RelationsSetToStats &b) { + if (a.has_distinct_count_hll && b.has_distinct_count_hll) { + return a.distinct_count_hll > b.distinct_count_hll; } - if (a.has_tdom_hll) { - return a.tdom_hll > b.tdom_no_hll; + if (a.has_distinct_count_hll) { + return a.distinct_count_hll > b.distinct_count_no_hll; } - if (b.has_tdom_hll) { - return a.tdom_no_hll > b.tdom_hll; + if (b.has_distinct_count_hll) { + return a.distinct_count_no_hll > b.distinct_count_hll; } - return a.tdom_no_hll > b.tdom_no_hll; + return a.distinct_count_no_hll > b.distinct_count_no_hll; } void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { @@ -442,7 +459,7 @@ void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { @@ -456,19 +473,21 @@ void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, //! 
the cardinality // Update the relation_to_tdom set with the estimated distinct count (or tdom) calculated above auto key = ColumnBinding(relation_id, i); - for (auto &relation_to_tdom : relations_to_tdoms) { + for (auto &relation_to_tdom : relation_set_stats) { column_binding_set_t i_set = relation_to_tdom.equivalent_relations; if (i_set.find(key) == i_set.end()) { continue; } auto distinct_count = stats.column_distinct_count.at(i); - if (distinct_count.from_hll && relation_to_tdom.has_tdom_hll) { - relation_to_tdom.tdom_hll = MaxValue(relation_to_tdom.tdom_hll, distinct_count.distinct_count); - } else if (distinct_count.from_hll && !relation_to_tdom.has_tdom_hll) { - relation_to_tdom.has_tdom_hll = true; - relation_to_tdom.tdom_hll = distinct_count.distinct_count; + if (distinct_count.from_hll && relation_to_tdom.has_distinct_count_hll) { + relation_to_tdom.distinct_count_hll = + MaxValue(relation_to_tdom.distinct_count_hll, distinct_count.distinct_count); + } else if (distinct_count.from_hll && !relation_to_tdom.has_distinct_count_hll) { + relation_to_tdom.has_distinct_count_hll = true; + relation_to_tdom.distinct_count_hll = distinct_count.distinct_count; } else { - relation_to_tdom.tdom_no_hll = MinValue(distinct_count.distinct_count, relation_to_tdom.tdom_no_hll); + relation_to_tdom.distinct_count_no_hll = + MinValue(distinct_count.distinct_count, relation_to_tdom.distinct_count_no_hll); } break; } @@ -477,9 +496,9 @@ void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, // LCOV_EXCL_START -void CardinalityEstimator::AddRelationNamesToTdoms(vector &stats) { +void CardinalityEstimator::AddRelationNamesToRelationStats(vector &stats) { #ifdef DEBUG - for (auto &total_domain : relations_to_tdoms) { + for (auto &total_domain : relation_set_stats) { for (auto &binding : total_domain.equivalent_relations) { D_ASSERT(binding.table_index < stats.size()); string column_name; @@ -494,14 +513,15 @@ void CardinalityEstimator::AddRelationNamesToTdoms(vector &stats) #endif } -void CardinalityEstimator::PrintRelationToTdomInfo() { - for (auto &total_domain : relations_to_tdoms) { +void CardinalityEstimator::PrintRelationStats() { + for (auto &total_domain : relation_set_stats) { string domain = "Following columns have the same distinct count: "; for (auto &column_name : total_domain.column_names) { domain += column_name + ", "; } - bool have_hll = total_domain.has_tdom_hll; - domain += "\n TOTAL DOMAIN = " + to_string(have_hll ? total_domain.tdom_hll : total_domain.tdom_no_hll); + bool have_hll = total_domain.has_distinct_count_hll; + domain += "\n TOTAL DOMAIN = " + + to_string(have_hll ? total_domain.distinct_count_hll : total_domain.distinct_count_no_hll); Printer::Print(domain); } } diff --git a/src/duckdb/src/optimizer/join_order/cost_model.cpp b/src/duckdb/src/optimizer/join_order/cost_model.cpp index bfe64412f..ef1e23972 100644 --- a/src/duckdb/src/optimizer/join_order/cost_model.cpp +++ b/src/duckdb/src/optimizer/join_order/cost_model.cpp @@ -8,6 +8,8 @@ CostModel::CostModel(QueryGraphManager &query_graph_manager) : query_graph_manager(query_graph_manager), cardinality_estimator() { } +// Currently cost of a join only factors in the cardinalities. +// If join types and join algorithms are to be considered, they should be added here. 
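A hedged aside on the comment above (the remainder of ComputeCost is not shown in this hunk, so this is a reading of the surrounding code rather than a definitive statement of the implementation): with only cardinalities in play, the cost of a join tree is effectively the classic C_out-style measure, i.e. the estimated size of each intermediate join result accumulated over the tree, roughly

\[ \mathrm{cost}(L \bowtie R) \;\approx\; |L \bowtie R|_{\mathrm{est}} + \mathrm{cost}(L) + \mathrm{cost}(R) \]

which is why improving the cardinality estimates directly improves the chosen join order.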
double CostModel::ComputeCost(DPJoinNode &left, DPJoinNode &right) { auto &combination = query_graph_manager.set_manager.Union(left.set, right.set); auto join_card = cardinality_estimator.EstimateCardinalityWithSet(combination); diff --git a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp index b90d22b0f..767b918c4 100644 --- a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp @@ -25,7 +25,6 @@ JoinOrderOptimizer JoinOrderOptimizer::CreateChildOptimizer() { unique_ptr JoinOrderOptimizer::Optimize(unique_ptr plan, optional_ptr stats) { - if (depth > query_graph_manager.context.config.max_expression_depth) { // Very deep plans will eventually consume quite some stack space // Returning the current plan is always a valid choice diff --git a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp index fc282aba9..ec39dbd05 100644 --- a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp +++ b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp @@ -102,7 +102,6 @@ const reference_map_t> &PlanEnumerator:: unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, const vector> &possible_connections, DPJoinNode &left, DPJoinNode &right) { - // FIXME: should consider different join algorithms, should we pick a join algorithm here as well? (probably) optional_ptr best_connection = possible_connections.back().get(); // cross products are technically still connections, but the filter expression is a null_ptr @@ -452,7 +451,7 @@ void PlanEnumerator::InitLeafPlans() { auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); cost_model.cardinality_estimator.InitEquivalentRelations(query_graph_manager.GetFilterBindings()); - cost_model.cardinality_estimator.AddRelationNamesToTdoms(relation_stats); + cost_model.cardinality_estimator.AddRelationNamesToRelationStats(relation_stats); // then update the total domains based on the cardinalities of each relation. 
for (idx_t i = 0; i < relation_stats.size(); i++) { diff --git a/src/duckdb/src/optimizer/join_order/query_graph.cpp b/src/duckdb/src/optimizer/join_order/query_graph.cpp index beb9e1521..01e167fec 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph.cpp @@ -79,7 +79,6 @@ void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, void QueryGraphEdges::EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, const std::function &callback) const { - for (auto &neighbor : info.get().neighbors) { if (callback(*neighbor)) { return; diff --git a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp index 94ee1a2c8..28f6a9eb3 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp @@ -243,7 +243,6 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorsecond; if (!dp_entry->second->is_leaf) { - // generate the left and right children auto left = GenerateJoins(extracted_relations, node->left_set); auto right = GenerateJoins(extracted_relations, node->right_set); diff --git a/src/duckdb/src/optimizer/join_order/relation_manager.cpp b/src/duckdb/src/optimizer/join_order/relation_manager.cpp index 4916f662e..78f61049b 100644 --- a/src/duckdb/src/optimizer/join_order/relation_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_manager.cpp @@ -46,7 +46,6 @@ void RelationManager::AddAggregateOrWindowRelation(LogicalOperator &op, optional void RelationManager::AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats) { - // if parent is null, then this is a root relation // if parent is not null, it should have multiple children D_ASSERT(!parent || parent->children.size() >= 2); @@ -54,6 +53,10 @@ void RelationManager::AddRelation(LogicalOperator &op, optional_ptr limit_op, RelationS } } +void RelationManager::AddRelationWithChildren(JoinOrderOptimizer &optimizer, LogicalOperator &op, + LogicalOperator &input_op, optional_ptr parent, + RelationStats &child_stats, optional_ptr limit_op, + vector> &datasource_filters) { + D_ASSERT(!op.children.empty()); + auto child_optimizer = optimizer.CreateChildOptimizer(); + op.children[0] = child_optimizer.Optimize(std::move(op.children[0]), &child_stats); + if (!datasource_filters.empty()) { + child_stats.cardinality = LossyNumericCast(static_cast(child_stats.cardinality) * + RelationStatisticsHelper::DEFAULT_SELECTIVITY); + } + ModifyStatsIfLimit(limit_op.get(), child_stats); + AddRelation(input_op, parent, child_stats); +} + bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, LogicalOperator &input_op, vector> &filter_operators, optional_ptr parent) { @@ -279,15 +297,7 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica case LogicalOperatorType::LOGICAL_UNNEST: { // optimize children of unnest RelationStats child_stats; - auto child_optimizer = optimizer.CreateChildOptimizer(); - op->children[0] = child_optimizer.Optimize(std::move(op->children[0]), &child_stats); - // the extracted cardinality should be set for window - if (!datasource_filters.empty()) { - child_stats.cardinality = LossyNumericCast(static_cast(child_stats.cardinality) * - RelationStatisticsHelper::DEFAULT_SELECTIVITY); - } - ModifyStatsIfLimit(limit_op.get(), child_stats); - AddRelation(input_op, parent, child_stats); + 
AddRelationWithChildren(optimizer, *op, input_op, parent, child_stats, limit_op, datasource_filters); return true; } case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { @@ -345,6 +355,14 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica case LogicalOperatorType::LOGICAL_GET: { // TODO: Get stats from a logical GET auto &get = op->Cast(); + // this is a get that *most likely* has a function (like unnest or json_each). + // there are new bindings for output of the function, but child bindings also exist, and can + // be used in joins + if (!op->children.empty()) { + RelationStats child_stats; + AddRelationWithChildren(optimizer, *op, input_op, parent, child_stats, limit_op, datasource_filters); + return true; + } auto stats = RelationStatisticsHelper::ExtractGetStats(get, context); // if there is another logical filter that could not be pushed down into the // table scan, apply another selectivity. @@ -542,7 +560,6 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op auto &join = f_op.Cast(); D_ASSERT(join.expressions.empty()); if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { - auto conjunction_expression = make_uniq(ExpressionType::CONJUNCTION_AND); // create a conjunction expression for the semi join. // It's possible multiple LHS relations have a condition in diff --git a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp index 6cf8dfea5..ad2ce3341 100644 --- a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp @@ -21,14 +21,14 @@ static ExpressionBinding GetChildColumnBinding(Expression &expr) { auto &func = expr.Cast(); // no children some sort of gen_random_uuid() or equivalent. if (func.children.empty()) { - ret.found_expression = true; + ret.expression = expr; ret.expression_is_constant = true; return ret; } break; } case ExpressionClass::BOUND_COLUMN_REF: { - ret.found_expression = true; + ret.expression = expr; auto &new_col_ref = expr.Cast(); ret.child_binding = ColumnBinding(new_col_ref.binding.table_index, new_col_ref.binding.column_index); return ret; @@ -38,16 +38,21 @@ static ExpressionBinding GetChildColumnBinding(Expression &expr) { case ExpressionClass::BOUND_DEFAULT: case ExpressionClass::BOUND_PARAMETER: case ExpressionClass::BOUND_REF: - ret.found_expression = true; + ret.expression = expr; ret.expression_is_constant = true; return ret; default: break; } ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { + if (ret.FoundColumnRef()) { + //! 
Already found a column ref expression + return; + } auto recursive_result = GetChildColumnBinding(*child); - if (recursive_result.found_expression) { + if (recursive_result.FoundExpression()) { ret = recursive_result; + return; } }); // we didn't find a Bound Column Ref @@ -163,7 +168,7 @@ RelationStats RelationStatisticsHelper::ExtractProjectionStats(LogicalProjection for (auto &expr : proj.expressions) { proj_stats.column_names.push_back(expr->GetName()); auto res = GetChildColumnBinding(*expr); - D_ASSERT(res.found_expression); + D_ASSERT(res.FoundExpression()); if (res.expression_is_constant) { proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); } else { diff --git a/src/duckdb/src/optimizer/late_materialization.cpp b/src/duckdb/src/optimizer/late_materialization.cpp index 4e5b0f13e..3b19b3612 100644 --- a/src/duckdb/src/optimizer/late_materialization.cpp +++ b/src/duckdb/src/optimizer/late_materialization.cpp @@ -1,4 +1,6 @@ #include "duckdb/optimizer/late_materialization.hpp" + +#include "duckdb/optimizer/late_materialization_helper.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" @@ -22,53 +24,6 @@ LateMaterialization::LateMaterialization(Optimizer &optimizer) : optimizer(optim max_row_count = DBConfig::GetSetting(optimizer.context); } -vector LateMaterialization::GetOrInsertRowIds(LogicalGet &get) { - auto &column_ids = get.GetMutableColumnIds(); - - vector result; - for (idx_t r_idx = 0; r_idx < row_id_column_ids.size(); ++r_idx) { - // check if it is already projected - auto row_id_column_id = row_id_column_ids[r_idx]; - auto &row_id_column = row_id_columns[r_idx]; - optional_idx row_id_index; - for (idx_t i = 0; i < column_ids.size(); ++i) { - if (column_ids[i].GetPrimaryIndex() == row_id_column_id) { - // already projected - return the id - row_id_index = i; - break; - } - } - if (row_id_index.IsValid()) { - result.push_back(row_id_index.GetIndex()); - continue; - } - // row id is not yet projected - push it and return the new index - column_ids.push_back(ColumnIndex(row_id_column_id)); - if (!get.projection_ids.empty()) { - get.projection_ids.push_back(column_ids.size() - 1); - } - if (!get.types.empty()) { - get.types.push_back(row_id_column.type); - } - result.push_back(column_ids.size() - 1); - } - return result; -} - -unique_ptr LateMaterialization::ConstructLHS(LogicalGet &get) { - // we need to construct a new scan of the same table - auto table_index = optimizer.binder.GenerateTableIndex(); - auto new_get = make_uniq(table_index, get.function, get.bind_data->Copy(), get.returned_types, - get.names, get.virtual_columns); - new_get->GetMutableColumnIds() = get.GetColumnIds(); - new_get->projection_ids = get.projection_ids; - new_get->parameters = get.parameters; - new_get->named_parameters = get.named_parameters; - new_get->input_table_types = get.input_table_types; - new_get->input_table_names = get.input_table_names; - return new_get; -} - vector LateMaterialization::ConstructRHS(unique_ptr &op) { // traverse down until we reach the LogicalGet vector> stack; @@ -80,7 +35,7 @@ vector LateMaterialization::ConstructRHS(unique_ptr(); - auto row_id_indexes = GetOrInsertRowIds(get); + auto row_id_indexes = LateMaterializationHelper::GetOrInsertRowIds(get, row_id_column_ids, row_id_columns); idx_t column_count = get.projection_ids.empty() ? 
get.GetColumnIds().size() : get.projection_ids.size(); D_ASSERT(column_count == get.GetColumnBindings().size()); @@ -281,12 +236,12 @@ bool LateMaterialization::TryLateMaterialization(unique_ptr &op // we need to ensure the operator returns exactly the same column bindings as before // construct the LHS from the LogicalGet - auto lhs = ConstructLHS(get); + auto lhs = LateMaterializationHelper::CreateLHSGet(get, optimizer.binder); // insert the row-id column on the left hand side auto &lhs_get = *lhs; auto lhs_index = lhs_get.table_index; auto lhs_columns = lhs_get.GetColumnIds().size(); - auto lhs_row_indexes = GetOrInsertRowIds(lhs_get); + auto lhs_row_indexes = LateMaterializationHelper::GetOrInsertRowIds(lhs_get, row_id_column_ids, row_id_columns); vector lhs_bindings; for (auto &lhs_row_index : lhs_row_indexes) { lhs_bindings.emplace_back(lhs_index, lhs_row_index); @@ -436,6 +391,11 @@ bool LateMaterialization::OptimizeLargeLimit(LogicalLimit &limit, idx_t limit_va } current_op = *current_op.get().children[0]; } + // if there are any filters we shouldn't do large limit optimization + auto &get = current_op.get().Cast(); + if (!get.table_filters.filters.empty()) { + return false; + } return true; } diff --git a/src/duckdb/src/optimizer/late_materialization_helper.cpp b/src/duckdb/src/optimizer/late_materialization_helper.cpp new file mode 100644 index 000000000..4b81ba3a1 --- /dev/null +++ b/src/duckdb/src/optimizer/late_materialization_helper.cpp @@ -0,0 +1,52 @@ +#include "duckdb/optimizer/late_materialization_helper.hpp" + +namespace duckdb { + +unique_ptr LateMaterializationHelper::CreateLHSGet(const LogicalGet &rhs, Binder &binder) { + // we need to construct a new scan of the same table + auto table_index = binder.GenerateTableIndex(); + auto new_get = make_uniq(table_index, rhs.function, rhs.bind_data->Copy(), rhs.returned_types, + rhs.names, rhs.virtual_columns); + new_get->GetMutableColumnIds() = rhs.GetColumnIds(); + new_get->projection_ids = rhs.projection_ids; + new_get->parameters = rhs.parameters; + new_get->named_parameters = rhs.named_parameters; + new_get->input_table_types = rhs.input_table_types; + new_get->input_table_names = rhs.input_table_names; + return new_get; +} + +vector LateMaterializationHelper::GetOrInsertRowIds(LogicalGet &get, const vector &row_id_column_ids, + const vector &row_id_columns) { + auto &column_ids = get.GetMutableColumnIds(); + + vector result; + for (idx_t r_idx = 0; r_idx < row_id_column_ids.size(); ++r_idx) { + // check if it is already projected + auto row_id_column_id = row_id_column_ids[r_idx]; + auto &row_id_column = row_id_columns[r_idx]; + optional_idx row_id_index; + for (idx_t i = 0; i < column_ids.size(); ++i) { + if (column_ids[i].GetPrimaryIndex() == row_id_column_id) { + // already projected - return the id + row_id_index = i; + break; + } + } + if (row_id_index.IsValid()) { + result.push_back(row_id_index.GetIndex()); + continue; + } + // row id is not yet projected - push it and return the new index + column_ids.push_back(ColumnIndex(row_id_column_id)); + if (!get.projection_ids.empty()) { + get.projection_ids.push_back(column_ids.size() - 1); + } + if (!get.types.empty()) { + get.types.push_back(row_id_column.type); + } + result.push_back(column_ids.size() - 1); + } + return result; +} +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index ce6cb0045..9bf4fcf8d 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp 
@@ -17,12 +17,14 @@ #include "duckdb/optimizer/filter_pullup.hpp" #include "duckdb/optimizer/filter_pushdown.hpp" #include "duckdb/optimizer/in_clause_rewriter.hpp" +#include "duckdb/optimizer/join_elimination.hpp" #include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" #include "duckdb/optimizer/join_order/join_order_optimizer.hpp" #include "duckdb/optimizer/limit_pushdown.hpp" #include "duckdb/optimizer/regex_range_filter.hpp" #include "duckdb/optimizer/remove_duplicate_groups.hpp" #include "duckdb/optimizer/remove_unused_columns.hpp" +#include "duckdb/optimizer/row_group_pruner.hpp" #include "duckdb/optimizer/rule/distinct_aggregate_optimizer.hpp" #include "duckdb/optimizer/rule/equal_or_null_simplification.hpp" #include "duckdb/optimizer/rule/in_clause_simplification.hpp" @@ -32,14 +34,17 @@ #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/optimizer/sum_rewriter.hpp" #include "duckdb/optimizer/topn_optimizer.hpp" +#include "duckdb/optimizer/topn_window_elimination.hpp" #include "duckdb/optimizer/unnest_rewriter.hpp" #include "duckdb/optimizer/late_materialization.hpp" +#include "duckdb/optimizer/common_subplan_optimizer.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/planner.hpp" namespace duckdb { Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), binder(binder), rewriter(context) { + rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); @@ -188,6 +193,11 @@ void Optimizer::RunBuiltInOptimizers() { plan = optimizer.Optimize(std::move(plan)); }); + RunOptimizer(OptimizerType::JOIN_ELIMINATION, [&]() { + JoinElimination join_elimination; + plan = join_elimination.Optimize(std::move(plan)); + }); + // rewrites UNNESTs in DelimJoins by moving them to the projection RunOptimizer(OptimizerType::UNNEST_REWRITER, [&]() { UnnestRewriter unnest_rewriter; @@ -225,12 +235,23 @@ void Optimizer::RunBuiltInOptimizers() { build_probe_side_optimizer.VisitOperator(*plan); }); + // convert common subplans into materialized CTEs + RunOptimizer(OptimizerType::COMMON_SUBPLAN, [&]() { + CommonSubplanOptimizer common_subplan_optimizer(*this); + plan = common_subplan_optimizer.Optimize(std::move(plan)); + }); + // pushes LIMIT below PROJECTION RunOptimizer(OptimizerType::LIMIT_PUSHDOWN, [&]() { LimitPushdown limit_pushdown; plan = limit_pushdown.Optimize(std::move(plan)); }); + RunOptimizer(OptimizerType::ROW_GROUP_PRUNER, [&]() { + RowGroupPruner row_group_pruner(context); + plan = row_group_pruner.Optimize(std::move(plan)); + }); + // perform sampling pushdown RunOptimizer(OptimizerType::SAMPLING_PUSHDOWN, [&]() { SamplingPushdown sampling_pushdown; @@ -257,6 +278,12 @@ void Optimizer::RunBuiltInOptimizers() { statistics_map = propagator.GetStatisticsMap(); }); + // rewrite row_number window function + filter on row_number to aggregate + RunOptimizer(OptimizerType::TOP_N_WINDOW_ELIMINATION, [&]() { + TopNWindowElimination topn_window_elimination(context, *this, &statistics_map); + plan = topn_window_elimination.Optimize(std::move(plan)); + }); + // remove duplicate aggregates RunOptimizer(OptimizerType::COMMON_AGGREGATE, [&]() { CommonAggregateOptimizer common_aggregate; diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp index 40ff4dc3c..f9dcbf84e 100644 --- 
a/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_cross_product.cpp @@ -52,9 +52,9 @@ unique_ptr FilterPushdown::PushdownCrossProduct(unique_ptrchildren[1], left_bindings, right_bindings, join_expressions, conditions, arbitrary_expressions); // create the join from the join conditions - auto new_op = LogicalComparisonJoin::CreateJoin(GetContext(), join_type, join_ref_type, - std::move(op->children[0]), std::move(op->children[1]), - std::move(conditions), std::move(arbitrary_expressions)); + auto new_op = LogicalComparisonJoin::CreateJoin(join_type, join_ref_type, std::move(op->children[0]), + std::move(op->children[1]), std::move(conditions), + std::move(arbitrary_expressions)); // possible cases are: AnyJoin, ComparisonJoin, or Filter + ComparisonJoin if (op->has_estimated_cardinality) { diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp index 90dbbb823..ac4b6532a 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" namespace duckdb { unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { @@ -48,7 +49,9 @@ unique_ptr FilterPushdown::PushdownGet(unique_ptr(std::move(op)); + } //! We generate the table filters that will be executed during the table scan vector pushdown_results; diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp index 8370f4ca9..e2e4730d1 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp @@ -14,6 +14,7 @@ unique_ptr FilterPushdown::PushdownInnerJoin(unique_ptrCast(); D_ASSERT(join.join_type == JoinType::INNER); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } // inner join: gather all the conditions of the inner join and add to the filter list diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp index 9e56ed9d6..1ebf3cedd 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp @@ -78,6 +78,7 @@ unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr &right_bindings) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } FilterPushdown left_pushdown(optimizer, convert_mark_joins), right_pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp index 3d81c68c1..648c5e3d4 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_outer_join.cpp @@ -174,7 +174,6 @@ PushDownFiltersOnCoalescedEqualJoinKeys(vector> &filters, unique_ptr FilterPushdown::PushdownOuterJoin(unique_ptr op, unordered_set &left_bindings, unordered_set &right_bindings) { - if (op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { return FinishPushdown(std::move(op)); 
} diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp index 7d240e3f6..0b937fe25 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp @@ -12,6 +12,7 @@ using Filter = FilterPushdown::Filter; unique_ptr FilterPushdown::PushdownSemiAntiJoin(unique_ptr op) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } diff --git a/src/duckdb/src/optimizer/regex_range_filter.cpp b/src/duckdb/src/optimizer/regex_range_filter.cpp index fd9f98fe6..987c579af 100644 --- a/src/duckdb/src/optimizer/regex_range_filter.cpp +++ b/src/duckdb/src/optimizer/regex_range_filter.cpp @@ -16,7 +16,6 @@ namespace duckdb { unique_ptr RegexRangeFilter::Rewrite(unique_ptr op) { - for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { op->children[child_idx] = Rewrite(std::move(op->children[child_idx])); } diff --git a/src/duckdb/src/optimizer/remove_unused_columns.cpp b/src/duckdb/src/optimizer/remove_unused_columns.cpp index 20817633a..345e40e74 100644 --- a/src/duckdb/src/optimizer/remove_unused_columns.cpp +++ b/src/duckdb/src/optimizer/remove_unused_columns.cpp @@ -1,5 +1,7 @@ #include "duckdb/optimizer/remove_unused_columns.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/pair.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" @@ -20,6 +22,7 @@ #include "duckdb/planner/operator/logical_set_operation.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/function/scalar/struct_utils.hpp" +#include namespace duckdb { @@ -31,6 +34,9 @@ void BaseColumnPruner::ReplaceBinding(ColumnBinding current_binding, ColumnBindi D_ASSERT(colref.binding == current_binding); colref.binding = new_binding; } + auto record = std::move(colrefs->second); + column_references.erase(current_binding); + column_references.insert(make_pair(new_binding, std::move(record))); } } @@ -205,6 +211,19 @@ void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { // in this case we only need to project a single constant proj.expressions.push_back(make_uniq(Value::INTEGER(42))); } + RemoveUnusedColumns remove(binder, context); + auto tmp_deliver = remove.deliver_child; + for (idx_t idx = 0; idx < proj.expressions.size(); idx++) { + auto &expr = proj.expressions[idx]; + auto record = column_references.find(ColumnBinding(proj.table_index, idx)); + if (record != column_references.end() && !record->second.child_columns.empty()) { + remove.deliver_child = record->second.child_columns; + } + remove.VisitExpression(&expr); + remove.deliver_child = tmp_deliver; + } + remove.VisitOperator(*op.children[0]); + return; } // then recurse into the children of this projection RemoveUnusedColumns remove(binder, context); @@ -228,89 +247,13 @@ void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { } case LogicalOperatorType::LOGICAL_GET: { LogicalOperatorVisitor::VisitOperatorExpressions(op); - if (everything_referenced) { - return; - } auto &get = op.Cast(); - if (!get.function.projection_pushdown) { - return; - } - - auto final_column_ids = get.GetColumnIds(); - - // Create "selection vector" of all column ids - vector proj_sel; - for (idx_t col_idx = 0; col_idx < final_column_ids.size(); 
col_idx++) { - proj_sel.push_back(col_idx); - } - // Create a copy that we can use to match ids later - auto col_sel = proj_sel; - // Clear unused ids, exclude filter columns that are projected out immediately - ClearUnusedExpressions(proj_sel, get.table_index, false); - - vector> filter_expressions; - // for every table filter, push a column binding into the column references map to prevent the column from - // being projected out - for (auto &filter : get.table_filters.filters) { - optional_idx index; - for (idx_t i = 0; i < final_column_ids.size(); i++) { - if (final_column_ids[i].GetPrimaryIndex() == filter.first) { - index = i; - break; - } - } - if (!index.IsValid()) { - throw InternalException("Could not find column index for table filter"); - } - - auto column_type = get.GetColumnType(ColumnIndex(filter.first)); - - ColumnBinding filter_binding(get.table_index, index.GetIndex()); - auto column_ref = make_uniq(std::move(column_type), filter_binding); - auto filter_expr = filter.second->ToExpression(*column_ref); - if (filter_expr->IsScalar()) { - filter_expr = std::move(column_ref); - } - VisitExpression(&filter_expr); - filter_expressions.push_back(std::move(filter_expr)); - } - - // Clear unused ids, include filter columns that are projected out immediately - ClearUnusedExpressions(col_sel, get.table_index); - - // Now set the column ids in the LogicalGet using the "selection vector" - vector column_ids; - column_ids.reserve(col_sel.size()); - for (auto col_sel_idx : col_sel) { - auto entry = column_references.find(ColumnBinding(get.table_index, col_sel_idx)); - if (entry == column_references.end()) { - throw InternalException("RemoveUnusedColumns - could not find referenced column"); - } - ColumnIndex new_index(final_column_ids[col_sel_idx].GetPrimaryIndex(), entry->second.child_columns); - column_ids.emplace_back(new_index); - } - if (column_ids.empty()) { - // this generally means we are only interested in whether or not anything exists in the table (e.g. - // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not - // need to read any of the columns - column_ids.emplace_back(get.GetAnyColumn()); - } - get.SetColumnIds(std::move(column_ids)); - - if (!get.function.filter_prune) { - return; - } - // Now set the projection cols by matching the "selection vector" that excludes filter columns - // with the "selection vector" that includes filter columns - idx_t col_idx = 0; - get.projection_ids.clear(); - for (auto proj_sel_idx : proj_sel) { - for (; col_idx < col_sel.size(); col_idx++) { - if (proj_sel_idx == col_sel[col_idx]) { - get.projection_ids.push_back(col_idx); - break; - } - } + RemoveColumnsFromLogicalGet(get); + if (!op.children.empty()) { + // Some LOGICAL_GET operators (e.g., table in out functions) may have a + // child operator. So we recurse into it if it exists. 
+ RemoveUnusedColumns remove(binder, context, true); + remove.VisitOperator(*op.children[0]); } return; } @@ -363,6 +306,93 @@ void RemoveUnusedColumns::VisitOperator(LogicalOperator &op) { } } +void RemoveUnusedColumns::RemoveColumnsFromLogicalGet(LogicalGet &get) { + if (everything_referenced) { + return; + } + if (!get.function.projection_pushdown) { + return; + } + + auto final_column_ids = get.GetColumnIds(); + + // Create "selection vector" of all column ids + vector proj_sel; + for (idx_t col_idx = 0; col_idx < final_column_ids.size(); col_idx++) { + proj_sel.push_back(col_idx); + } + // Create a copy that we can use to match ids later + auto col_sel = proj_sel; + // Clear unused ids, exclude filter columns that are projected out immediately + ClearUnusedExpressions(proj_sel, get.table_index, false); + + vector> filter_expressions; + // for every table filter, push a column binding into the column references map to prevent the column from + // being projected out + for (auto &filter : get.table_filters.filters) { + optional_idx index; + for (idx_t i = 0; i < final_column_ids.size(); i++) { + if (final_column_ids[i].GetPrimaryIndex() == filter.first) { + index = i; + break; + } + } + if (!index.IsValid()) { + throw InternalException("Could not find column index for table filter"); + } + + auto column_type = get.GetColumnType(ColumnIndex(filter.first)); + + ColumnBinding filter_binding(get.table_index, index.GetIndex()); + auto column_ref = make_uniq(std::move(column_type), filter_binding); + auto filter_expr = filter.second->ToExpression(*column_ref); + if (filter_expr->IsScalar()) { + filter_expr = std::move(column_ref); + } + VisitExpression(&filter_expr); + filter_expressions.push_back(std::move(filter_expr)); + } + + // Clear unused ids, include filter columns that are projected out immediately + ClearUnusedExpressions(col_sel, get.table_index); + + // Now set the column ids in the LogicalGet using the "selection vector" + vector column_ids; + column_ids.reserve(col_sel.size()); + for (idx_t idx = 0; idx < col_sel.size(); idx++) { + auto col_sel_idx = col_sel[idx]; + auto entry = column_references.find(ColumnBinding(get.table_index, idx)); + if (entry == column_references.end()) { + throw InternalException("RemoveUnusedColumns - could not find referenced column"); + } + ColumnIndex new_index(final_column_ids[col_sel_idx].GetPrimaryIndex(), entry->second.child_columns); + column_ids.emplace_back(new_index); + } + if (column_ids.empty()) { + // this generally means we are only interested in whether or not anything exists in the table (e.g. 
+ // EXISTS(SELECT * FROM tbl)) in this case, we just scan the row identifier column as it means we do not + // need to read any of the columns + column_ids.emplace_back(get.GetAnyColumn()); + } + get.SetColumnIds(std::move(column_ids)); + + if (!get.function.filter_prune) { + return; + } + // Now set the projection cols by matching the "selection vector" that excludes filter columns + // with the "selection vector" that includes filter columns + idx_t col_idx = 0; + get.projection_ids.clear(); + for (auto proj_sel_idx : proj_sel) { + for (; col_idx < col_sel.size(); col_idx++) { + if (proj_sel_idx == col_sel[col_idx]) { + get.projection_ids.push_back(col_idx); + break; + } + } + } +} + bool BaseColumnPruner::HandleStructExtractRecursive(Expression &expr, optional_ptr &colref, vector &indexes) { if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { @@ -416,6 +446,14 @@ bool BaseColumnPruner::HandleStructExtract(Expression &expr) { return true; } +bool BaseColumnPruner::HandleStructPack(Expression &expr) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return false; + } + auto &function = expr.Cast(); + return function.function.name == "struct_pack"; +} + void MergeChildColumns(vector ¤t_child_columns, ColumnIndex &new_child_column) { if (current_child_columns.empty()) { // there's already a reference to the full column - we can't extract only a subfield @@ -480,6 +518,13 @@ void BaseColumnPruner::VisitExpression(unique_ptr *expression) { // already handled return; } + if (HandleStructPack(expr)) { + auto tmp_deliver = deliver_child; + deliver_child = vector(); + LogicalOperatorVisitor::VisitExpression(expression); + deliver_child = tmp_deliver; + return; + } // recurse LogicalOperatorVisitor::VisitExpression(expression); } @@ -487,7 +532,14 @@ void BaseColumnPruner::VisitExpression(unique_ptr *expression) { unique_ptr BaseColumnPruner::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { // add a reference to the entire column - AddBinding(expr); + if (deliver_child.empty()) { + AddBinding(expr); + } else { + for (auto &child_idx : deliver_child) { + AddBinding(expr, child_idx); + } + } + return nullptr; } diff --git a/src/duckdb/src/optimizer/row_group_pruner.cpp b/src/duckdb/src/optimizer/row_group_pruner.cpp new file mode 100644 index 000000000..b572ea896 --- /dev/null +++ b/src/duckdb/src/optimizer/row_group_pruner.cpp @@ -0,0 +1,193 @@ +#include "duckdb/optimizer/row_group_pruner.hpp" + +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/execution/operator/join/join_filter_pushdown.hpp" +#include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" +#include "duckdb/planner/operator/logical_order.hpp" +#include "duckdb/storage/table/row_group_reorderer.hpp" + +namespace duckdb { + +RowGroupPruner::RowGroupPruner(ClientContext &context_p) : context(context_p) { +} + +unique_ptr RowGroupPruner::Optimize(unique_ptr op) { + if (!TryOptimize(*op)) { + for (auto &child : op->children) { + child = Optimize(std::move(child)); + } + } + + return op; +} + +bool RowGroupPruner::TryOptimize(LogicalOperator &op) const { + optional_idx row_limit; + optional_idx row_offset; + + if (op.type != LogicalOperatorType::LOGICAL_LIMIT) { + return false; + } + + auto &logical_limit = op.Cast(); + GetLimitAndOffset(logical_limit, row_limit, row_offset); + auto logical_order = FindLogicalOrder(logical_limit); + if (!logical_order) { + return 
false; + } + + // Only reorder row groups if there are no additional limit operators since they could modify the order + reference current_op = *logical_order; + while (!current_op.get().children.empty()) { + if (current_op.get().children.size() > 1) { + return false; + } + + const auto op_type = current_op.get().type; + if (op_type == LogicalOperatorType::LOGICAL_LIMIT) { + return false; + } + if (op_type == LogicalOperatorType::LOGICAL_FILTER || + op_type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + row_limit.SetInvalid(); + row_offset.SetInvalid(); + } + current_op = *current_op.get().children[0]; + } + + column_t column_index; + auto logical_get = FindLogicalGet(*logical_order, column_index); + if (!logical_get) { + return false; + } + + if (!logical_get->table_filters.filters.empty()) { + // If there are filters, we only order the row groups but do not prune + row_limit.SetInvalid(); + row_offset.SetInvalid(); + } + + // We can use the RowGroupReorderer! + const auto &primary_order = logical_order->orders[0]; + auto options = + CreateRowGroupReordererOptions(row_limit, row_offset, primary_order, *logical_get, column_index, logical_limit); + logical_get->function.set_scan_order(std::move(options), logical_get->bind_data.get()); + + return true; +} + +void RowGroupPruner::GetLimitAndOffset(const LogicalLimit &logical_limit, optional_idx &row_limit, + optional_idx &row_offset) const { + if (logical_limit.limit_val.Type() == LimitNodeType::CONSTANT_VALUE) { + row_limit = logical_limit.limit_val.GetConstantValue(); + } else if (logical_limit.limit_val.Type() == LimitNodeType::UNSET) { + row_limit = 0; + } + + if (logical_limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { + row_offset = logical_limit.offset_val.GetConstantValue(); + } else if (logical_limit.offset_val.Type() == LimitNodeType::UNSET) { + row_offset = 0; + } +} + +optional_ptr RowGroupPruner::FindLogicalOrder(const LogicalLimit &logical_limit) const { + reference current_op = *logical_limit.children[0]; + while (current_op.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + current_op = *current_op.get().children[0]; + } + + if (current_op.get().type != LogicalOperatorType::LOGICAL_ORDER_BY) { + return nullptr; + } + + auto &logical_order = current_op.get().Cast(); + for (const auto &order : logical_order.orders) { + // We do not support any null-first orders as this requires unimplemented logic in the row group reorderer + if (order.null_order == OrderByNullType::NULLS_FIRST) { + return nullptr; + } + } + + auto order_column_type = logical_order.orders[0].expression->return_type; + if (!order_column_type.IsNumeric() && !order_column_type.IsTemporal() && + order_column_type != LogicalType::VARCHAR) { + return nullptr; + } + + if (logical_order.orders[0].expression->type != ExpressionType::BOUND_COLUMN_REF) { + return nullptr; + } + + return logical_order; +} + +optional_ptr RowGroupPruner::FindLogicalGet(const LogicalOrder &logical_order, + column_t &column_index) const { + const auto &primary_order = logical_order.orders[0]; + auto &colref = primary_order.expression->Cast(); + + vector columns {JoinFilterPushdownColumn {colref.binding}}; + vector pushdown_targets; + JoinFilterPushdownOptimizer::GetPushdownFilterTargets(*logical_order.children[0], std::move(columns), + pushdown_targets); + + if (pushdown_targets.empty()) { + return nullptr; + } + + D_ASSERT(pushdown_targets.size() == 1); + auto &logical_get = pushdown_targets.front().get; + + if (!logical_get.function.set_scan_order) { + return 
nullptr; + } + + auto col_idx = pushdown_targets[0].columns[0].probe_column_index.column_index; + column_index = logical_get.GetColumnIds()[col_idx].GetPrimaryIndex(); + + return logical_get; +} + +unique_ptr +RowGroupPruner::CreateRowGroupReordererOptions(const optional_idx row_limit, const optional_idx row_offset, + const BoundOrderByNode &primary_order, const LogicalGet &logical_get, + const column_t column_index, LogicalLimit &logical_limit) const { + auto &colref = primary_order.expression->Cast(); + auto column_type = + colref.return_type == LogicalType::VARCHAR ? OrderByColumnType::STRING : OrderByColumnType::NUMERIC; + auto order_type = primary_order.type == OrderType::ASCENDING ? RowGroupOrderType::ASC : RowGroupOrderType::DESC; + auto order_by = order_type == RowGroupOrderType::ASC ? OrderByStatistics::MIN : OrderByStatistics::MAX; + optional_idx combined_limit = row_limit.IsValid() + ? row_limit.GetIndex() + (row_offset.IsValid() ? row_offset.GetIndex() : 0) + : optional_idx(); + + if (row_offset.IsValid() && row_offset.GetIndex() > 0 && logical_get.function.get_partition_stats) { + // Try to prune with offset + GetPartitionStatsInput input(logical_get.function, logical_get.bind_data.get()); + auto partition_stats = logical_get.function.get_partition_stats(context, input); + if (!partition_stats.empty()) { + auto offset_puning_result = RowGroupReorderer::GetOffsetAfterPruning( + order_by, column_type, order_type, column_index, row_offset.GetIndex(), partition_stats); + if (offset_puning_result.pruned_row_group_count > 0) { + // We can prune row groups and reduce the offset + logical_limit.offset_val = + BoundLimitNode::ConstantValue(NumericCast(offset_puning_result.offset_remainder)); + + if (combined_limit.IsValid()) { + combined_limit = row_limit.GetIndex() + offset_puning_result.offset_remainder; + } + + return make_uniq(column_index, order_by, order_type, column_type, combined_limit, + offset_puning_result.pruned_row_group_count); + } + } + } + // Only sort row groups by primary order column and prune with limit if set + return make_uniq(column_index, order_by, order_type, column_type, combined_limit, + NumericCast(0)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp index dc778cfff..5fe22818e 100644 --- a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp @@ -17,7 +17,6 @@ ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &r unique_ptr ComparisonSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { - auto &expr = bindings[0].get().Cast(); auto &constant_expr = bindings[1].get(); bool column_ref_left = expr.left.get() != &constant_expr; diff --git a/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp new file mode 100644 index 000000000..09376c3d3 --- /dev/null +++ b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp @@ -0,0 +1,127 @@ +#include "duckdb/optimizer/rule/constant_order_normalization.hpp" + +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +class RecursiveFunctionExpressionMatcher : public ExpressionMatcher { +public: + explicit RecursiveFunctionExpressionMatcher(vector> 
func_matchers) + : func_matchers(std::move(func_matchers)) { + } + bool Match(Expression &expr, vector> &bindings) override { + FunctionExpressionMatcher *target_matcher = nullptr; + for (const auto &matcher : func_matchers) { + if (matcher->Match(expr, bindings)) { + target_matcher = matcher.get(); + break; + } + } + if (target_matcher == nullptr) { + return false; + } + bindings.clear(); + RecursiveMatch(target_matcher, expr, bindings); + bindings.push_back(expr); + return true; + } + +private: + void RecursiveMatch(FunctionExpressionMatcher *func_matcher, Expression &expr, + vector> &bindings) { + vector> curr_bindings; + if (func_matcher->Match(expr, curr_bindings)) { + auto &func_expr = expr.Cast(); + for (auto &child : func_expr.children) { + RecursiveMatch(func_matcher, *(child.get()), bindings); + } + } else { + bindings.push_back(expr); + } + } + + vector> func_matchers; +}; + +ConstantOrderNormalizationRule::ConstantOrderNormalizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // '+' and '*' satisfy commutative law and associative law. + auto add_matcher = make_uniq(); + add_matcher->function = make_uniq("+"); + add_matcher->type = make_uniq(); + auto left_expression_matcher = make_uniq(); + auto right_expression_matcher = make_uniq(); + left_expression_matcher->type = make_uniq(); + right_expression_matcher->type = make_uniq(); + add_matcher->matchers.push_back(std::move(left_expression_matcher)); + add_matcher->matchers.push_back(std::move(right_expression_matcher)); + add_matcher->policy = SetMatcher::Policy::ORDERED; + + auto multiply_matcher = make_uniq(); + multiply_matcher->function = make_uniq("*"); + multiply_matcher->type = make_uniq(); + left_expression_matcher = make_uniq(); + right_expression_matcher = make_uniq(); + left_expression_matcher->type = make_uniq(); + right_expression_matcher->type = make_uniq(); + multiply_matcher->matchers.push_back(std::move(left_expression_matcher)); + multiply_matcher->matchers.push_back(std::move(right_expression_matcher)); + multiply_matcher->policy = SetMatcher::Policy::ORDERED; + + vector> func_matchers; + func_matchers.push_back(std::move(add_matcher)); + func_matchers.push_back(std::move(multiply_matcher)); + auto op = make_uniq(std::move(func_matchers)); + root = std::move(op); +} + +unique_ptr ConstantOrderNormalizationRule::Apply(LogicalOperator &op, + vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings.back().get().Cast(); + + // Put all constant expressions in front. + vector> ordered_bindings; + vector> remain_bindings; + idx_t last_constant_position = 0; + for (idx_t i = 0; i < bindings.size() - 1; ++i) { + if (bindings[i].get().IsFoldable()) { + ordered_bindings.push_back(bindings[i]); + last_constant_position = i; + } else { + remain_bindings.push_back(bindings[i]); + } + } + + if (ordered_bindings.size() <= 1 || last_constant_position == ordered_bindings.size() - 1) { + return nullptr; + } + ordered_bindings.insert(ordered_bindings.end(), remain_bindings.begin(), remain_bindings.end()); + + // Reconstruct the expression. + FunctionBinder binder(rewriter.context); + ErrorData error; + unique_ptr new_root = ordered_bindings[0].get().Copy(); + vector> children; + children.push_back(std::move(new_root)); + for (idx_t i = 1; i < ordered_bindings.size(); ++i) { + // Right child. 
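+ // The previously accumulated expression stays as the left child; the next operand is appended as the right child before rebinding the same commutative, associative function.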
+ children.push_back(ordered_bindings[i].get().Copy()); + new_root = + binder.BindScalarFunction(DEFAULT_SCHEMA, root.function.name, std::move(children), error, root.is_operator); + if (!new_root) { + error.Throw(); + } + children.clear(); + // Left child. + children.push_back(std::move(new_root)); + } + + D_ASSERT(children.size() == 1); + D_ASSERT(children[0]->return_type == root.return_type); + + return std::move(children[0]); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp index 392574de3..9e760a2f3 100644 --- a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp @@ -246,7 +246,7 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_GREATERTHAN: - // date_trunc(part, column) <= constant_rhs --> column <= date_trunc(part, date_add(constant_rhs, + // date_trunc(part, column) <= constant_rhs --> column < date_trunc(part, date_add(constant_rhs, // INTERVAL 1 part)) // date_trunc(part, column) > constant_rhs --> column >= date_trunc(part, date_add(constant_rhs, // INTERVAL 1 part)) @@ -265,13 +265,19 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v expr.left = std::move(trunc); } - // If this is a >, we need to change it to >= for correctness. + // > needs to become >=, and <= needs to become <. if (rhs_comparison_type == ExpressionType::COMPARE_GREATERTHAN) { if (col_is_lhs) { expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_GREATERTHANOREQUALTO); } else { expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_LESSTHANOREQUALTO); } + } else { + if (col_is_lhs) { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_LESSTHAN); + } else { + expr.SetExpressionTypeUnsafe(ExpressionType::COMPARE_GREATERTHAN); + } } changes_made = true; diff --git a/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp index 7eeae1d34..c6031a01a 100644 --- a/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp +++ b/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp @@ -17,7 +17,7 @@ unique_ptr DistinctAggregateOptimizer::Apply(ClientContext &context, // no DISTINCT defined return nullptr; } - if (aggr.function.distinct_dependent == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { + if (aggr.function.GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { // not a distinct-sensitive aggregate but we have an DISTINCT modifier - remove it aggr.aggr_type = AggregateType::NON_DISTINCT; changes_made = true; @@ -47,7 +47,7 @@ unique_ptr DistinctWindowedOptimizer::Apply(ClientContext &context, // not an aggregate return nullptr; } - if (wexpr.aggregate->distinct_dependent == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { + if (wexpr.aggregate->GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { // not a distinct-sensitive aggregate but we have an DISTINCT modifier - remove it wexpr.distinct = false; changes_made = true; diff --git a/src/duckdb/src/optimizer/rule/enum_comparison.cpp b/src/duckdb/src/optimizer/rule/enum_comparison.cpp index 5dfcfaf30..0553285eb 100644 --- a/src/duckdb/src/optimizer/rule/enum_comparison.cpp +++ b/src/duckdb/src/optimizer/rule/enum_comparison.cpp @@ -47,7 +47,6 @@ bool AreMatchesPossible(LogicalType &left, LogicalType &right) { } 
unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); auto &left_child = bindings[1].get().Cast(); auto &right_child = bindings[3].get().Cast(); diff --git a/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp index 8f6435e1f..a87d66d08 100644 --- a/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp +++ b/src/duckdb/src/optimizer/rule/ordered_aggregate_optimizer.cpp @@ -17,12 +17,14 @@ OrderedAggregateOptimizer::OrderedAggregateOptimizer(ExpressionRewriter &rewrite } unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, BoundAggregateExpression &aggr, - vector> &groups, bool &changes_made) { + vector> &groups, + optional_ptr> grouping_sets, + bool &changes_made) { if (!aggr.order_bys) { // no ORDER BYs defined return nullptr; } - if (aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { + if (aggr.function.GetOrderDependent() == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { // not an order dependent aggregate but we have an ORDER BY clause - remove it aggr.order_bys.reset(); changes_made = true; @@ -30,7 +32,7 @@ unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, } // Remove unnecessary ORDER BY clauses and return if nothing remains - if (aggr.order_bys->Simplify(groups)) { + if (aggr.order_bys->Simplify(groups, grouping_sets)) { aggr.order_bys.reset(); changes_made = true; return nullptr; @@ -90,7 +92,8 @@ unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, unique_ptr OrderedAggregateOptimizer::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { auto &aggr = bindings[0].get().Cast(); - return Apply(rewriter.context, aggr, op.Cast().groups, changes_made); + return Apply(rewriter.context, aggr, op.Cast().groups, op.Cast().grouping_sets, + changes_made); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp index 24786867b..3a0697e99 100644 --- a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp +++ b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp @@ -184,6 +184,13 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< if (!escaped_like_string.exists) { return nullptr; } + + // if regexp had options, remove them so the new Contains Expression can be matched for other optimizers. 
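+ // (the options string, when present, is supplied as the third child of the regexp call; the contains() replacement only takes two arguments)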
+ if (root.children.size() == 3) { + root.children.pop_back(); + D_ASSERT(root.children.size() == 2); + } + auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); auto contains = make_uniq(root.return_type, GetStringContains(), std::move(root.children), nullptr); diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp index d0e24dac6..8547183a2 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp @@ -15,11 +15,11 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundAggreg stats.push_back(stat->Copy()); } } - if (!aggr.function.statistics) { + if (!aggr.function.HasStatisticsCallback()) { return nullptr; } AggregateStatisticsInput input(aggr.bind_info.get(), stats, node_stats.get()); - return aggr.function.statistics(context, aggr, input); + return aggr.function.GetStatisticsCallback()(context, aggr, input); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp index 2a0263450..336fe9055 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp @@ -15,11 +15,11 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundFuncti stats.push_back(stat->Copy()); } } - if (!func.function.statistics) { + if (!func.function.HasStatisticsCallback()) { return nullptr; } FunctionStatisticsInput input(func, func.bind_info.get(), stats, &expr_ptr); - return func.function.statistics(context, input); + return func.function.GetStatisticsCallback()(context, input); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp index 831317687..45a961cab 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp @@ -1,4 +1,13 @@ +#include "duckdb/common/assert.hpp" +#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/enums/tuple_data_layout_enums.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/function/partition_stats.hpp" #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/planner/operator/logical_dummy_scan.hpp" @@ -7,19 +16,139 @@ #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" namespace duckdb { +namespace { + +struct ValueComparator { + virtual ~ValueComparator() = default; + virtual bool Compare(Value &lhs, Value &rhs) const = 0; + virtual Value GetVal(BaseStatistics &stats) const = 0; +}; + +template +struct MinValueComp : public ValueComparator { + bool Compare(Value &lhs, Value &rhs) const override { + return lhs < rhs; + } + Value GetVal(BaseStatistics &stats) 
const override { + return StatsType::Min(stats); + } +}; + +template +struct MaxValueComp : public ValueComparator { + bool Compare(Value &lhs, Value &rhs) const override { + return lhs > rhs; + } + Value GetVal(BaseStatistics &stats) const override { + return StatsType::Max(stats); + } +}; + +template +unique_ptr GetComparator(const string &fun_name) { + if (fun_name == "min") { + return make_uniq>(); + } + D_ASSERT(fun_name == "max"); + return make_uniq>(); +} + +unique_ptr GetComparator(const string &fun_name, const LogicalType &type) { + if (type == LogicalType::VARCHAR) { + return GetComparator(fun_name); + } else if (type.IsNumeric() || type.IsTemporal()) { + return GetComparator(fun_name); + } + return nullptr; +} + +bool TryGetValueFromStats(const PartitionStatistics &stats, const column_t column_index, + const ValueComparator &comparator, Value &result) { + if (!stats.partition_row_group) { + return false; + } + auto column_stats = stats.partition_row_group->GetColumnStatistics(column_index); + if (!stats.partition_row_group->MinMaxIsExact(*column_stats)) { + return false; + } + if (column_stats->GetStatsType() == StatisticsType::NUMERIC_STATS) { + if (!NumericStats::HasMinMax(*column_stats)) { + // TODO: This also returns if an entire row group is null. In that case, we could skip/compare null + return false; + } + } else { + D_ASSERT(column_stats->GetStatsType() == StatisticsType::STRING_STATS); + if (StringStats::Min(*column_stats) > StringStats::Max(*column_stats)) { + // No min/max statistics availabe + return false; + } + } + result = comparator.GetVal(*column_stats); + return true; +} + +} // namespace + void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_ptr &node_ptr) { if (!aggr.groups.empty()) { // not possible with groups return; } + // check if all aggregates are COUNT(*), MIN or MAX + vector count_star_idxs; + vector min_max_bindings; + vector> comparators; + + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + auto &aggr_ref = aggr.expressions[i]; + if (aggr_ref->GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE) { + // not an aggregate + return; + } + auto &aggr_expr = aggr_ref->Cast(); + if (aggr_expr.filter) { + // aggregate has a filter - bail + return; + } + const string &fun_name = aggr_expr.function.name; + if (fun_name == "min" || fun_name == "max") { + if (aggr_expr.children.size() != 1 || aggr_expr.children[0]->type != ExpressionType::BOUND_COLUMN_REF) { + return; + } + const auto &col_ref = aggr_expr.children[0]->Cast(); + min_max_bindings.push_back(col_ref.binding); + auto comparator = GetComparator(fun_name, col_ref.return_type); + if (!comparator) { + // Type has no min max statistics + return; + } + comparators.push_back(std::move(comparator)); + } else if (fun_name == "count_star") { + count_star_idxs.push_back(i); + } else { + // aggregate is not count star, min or max - bail + return; + } + } + // skip any projections reference child_ref = *aggr.children[0]; while (child_ref.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + for (auto &binding : min_max_bindings) { + auto &expr = child_ref.get().expressions[binding.column_index]; + if (expr->type != ExpressionType::BOUND_COLUMN_REF) { + return; + } + binding = expr->Cast().binding; + } child_ref = *child_ref.get().children[0]; } + if (child_ref.get().type != LogicalOperatorType::LOGICAL_GET) { // child must be a LOGICAL_GET return; @@ -33,22 +162,11 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p // we cannot do this 
if the GET has filters return; } - // check if all aggregates are COUNT(*) - for (auto &aggr_ref : aggr.expressions) { - if (aggr_ref->GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE) { - // not an aggregate - return; - } - auto &aggr_expr = aggr_ref->Cast(); - if (aggr_expr.function.name != "count_star") { - // aggregate is not count star - bail - return; - } - if (aggr_expr.filter) { - // aggregate has a filter - bail - return; - } + if (get.extra_info.sample_options) { + // only use row group statistics if we query the whole table + return; } + // we can do the rewrite! get the stats GetPartitionStatsInput input(get.function, get.bind_data.get()); auto partition_stats = get.function.get_partition_stats(context, input); @@ -56,27 +174,59 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p // no partition stats found return; } - idx_t count = 0; - for (auto &stats : partition_stats) { - if (stats.count_type == CountType::COUNT_APPROXIMATE) { - // we cannot get an exact count - return; + + vector types; + vector> agg_results; + + if (!min_max_bindings.empty()) { + // Execute min/max aggregates on partition statistics + for (idx_t agg_idx = 0; agg_idx < min_max_bindings.size(); agg_idx++) { + const auto &binding = min_max_bindings[agg_idx]; + const column_t column_index = get.GetColumnIds()[binding.column_index].GetPrimaryIndex(); + auto &comparator = comparators[agg_idx]; + + Value agg_result; + if (!TryGetValueFromStats(partition_stats[0], column_index, *comparator, agg_result)) { + return; + } + for (idx_t partition_idx = 1; partition_idx < partition_stats.size(); partition_idx++) { + Value rhs; + if (!TryGetValueFromStats(partition_stats[partition_idx], column_index, *comparator, rhs)) { + return; + } + if (!comparator->Compare(agg_result, rhs)) { + agg_result = rhs; + } + } + types.push_back(agg_result.GetTypeMutable()); + auto expr = make_uniq(agg_result); + agg_results.push_back(std::move(expr)); + } + } + if (!count_star_idxs.empty()) { + // Execute count_star aggregates on partition statistics + idx_t count = 0; + for (const auto &stats : partition_stats) { + if (stats.count_type == CountType::COUNT_APPROXIMATE) { + // we cannot get an exact count + return; + } + count += stats.count; + } + for (const auto count_star_idx : count_star_idxs) { + auto count_result = make_uniq(Value::BIGINT(NumericCast(count))); + agg_results.emplace(agg_results.begin() + NumericCast(count_star_idx), std::move(count_result)); + types.insert(types.begin() + NumericCast(count_star_idx), LogicalType::BIGINT); } - count += stats.count; } - // we got an exact count - replace the entire aggregate with a scan of the result - vector types; - vector> count_results; - for (idx_t aggregate_index = 0; aggregate_index < aggr.expressions.size(); ++aggregate_index) { - auto count_result = make_uniq(Value::BIGINT(NumericCast(count))); - count_result->SetAlias(aggr.expressions[aggregate_index]->GetName()); - count_results.push_back(std::move(count_result)); - types.push_back(LogicalType::BIGINT); + // Set column names + for (idx_t expr_idx = 0; expr_idx < agg_results.size(); expr_idx++) { + agg_results[expr_idx]->SetAlias(aggr.expressions[expr_idx]->GetAlias()); } vector>> expressions; - expressions.push_back(std::move(count_results)); + expressions.push_back(std::move(agg_results)); auto expression_get = make_uniq(aggr.aggregate_index, std::move(types), std::move(expressions)); expression_get->children.push_back(make_uniq(aggr.group_index)); diff --git 
a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp index b01f7d704..a7aeec267 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp @@ -46,7 +46,12 @@ FilterPropagateResult StatisticsPropagator::PropagateTableFilter(ColumnBinding s // replace BoundColumnRefs with BoundRefs ExpressionFilter::ReplaceExpressionRecursive(filter_expr, *colref, ExpressionType::BOUND_COLUMN_REF); expr_filter.expr = std::move(filter_expr); - return propagate_result; + + // If we were able to prune solely based on the expression, return that result + if (propagate_result != FilterPropagateResult::NO_PRUNING_POSSIBLE) { + return propagate_result; + } + // Otherwise, check the statistics } return filter.CheckStatistics(stats); } diff --git a/src/duckdb/src/optimizer/topn_optimizer.cpp b/src/duckdb/src/optimizer/topn_optimizer.cpp index e42c748cb..929f5eb62 100644 --- a/src/duckdb/src/optimizer/topn_optimizer.cpp +++ b/src/duckdb/src/optimizer/topn_optimizer.cpp @@ -5,12 +5,15 @@ #include "duckdb/planner/operator/logical_limit.hpp" #include "duckdb/planner/operator/logical_order.hpp" #include "duckdb/planner/operator/logical_top_n.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/constant_filter.hpp" #include "duckdb/planner/filter/dynamic_filter.hpp" +#include "duckdb/planner/filter/null_filter.hpp" #include "duckdb/planner/filter/optional_filter.hpp" #include "duckdb/execution/operator/join/join_filter_pushdown.hpp" #include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/storage/table/scan_state.hpp" namespace duckdb { @@ -39,6 +42,9 @@ bool TopN::CanOptimize(LogicalOperator &op, optional_ptr context) if (child_op->has_estimated_cardinality) { // only check if we should switch to full sorting if we have estimated cardinality auto constant_limit = static_cast(limit.limit_val.GetConstantValue()); + if (limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { + constant_limit += static_cast(limit.offset_val.GetConstantValue()); + } auto child_card = static_cast(child_op->estimated_cardinality); // if the limit is > 0.7% of the child cardinality, sorting the whole table is faster @@ -60,11 +66,7 @@ bool TopN::CanOptimize(LogicalOperator &op, optional_ptr context) void TopN::PushdownDynamicFilters(LogicalTopN &op) { // pushdown dynamic filters through the Top-N operator - if (op.orders[0].null_order == OrderByNullType::NULLS_FIRST) { - // FIXME: not supported for NULLS FIRST quite yet - // we can support NULLS FIRST by doing (x IS NULL) OR [boundary value] - return; - } + bool nulls_first = op.orders[0].null_order == OrderByNullType::NULLS_FIRST; auto &type = op.orders[0].expression->return_type; if (!TypeIsIntegral(type.InternalType()) && type.id() != LogicalTypeId::VARCHAR) { // only supported for integral types currently @@ -117,7 +119,14 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { // create the actual dynamic filter auto dynamic_filter = make_uniq(filter_data); - auto optional_filter = make_uniq(std::move(dynamic_filter)); + unique_ptr pushed_filter = std::move(dynamic_filter); + if (nulls_first) { + auto or_filter = make_uniq(); + or_filter->child_filters.push_back(make_uniq()); + or_filter->child_filters.push_back(std::move(pushed_filter)); + pushed_filter = std::move(or_filter); + } + auto 
optional_filter = make_uniq(std::move(pushed_filter)); // push the filter into the table scan auto &column_index = get.GetColumnIds()[col_idx]; @@ -127,7 +136,6 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { unique_ptr TopN::Optimize(unique_ptr op) { if (CanOptimize(*op, &context)) { - vector> projections; // traverse operator tree and collect all projection nodes until we reach diff --git a/src/duckdb/src/optimizer/topn_window_elimination.cpp b/src/duckdb/src/optimizer/topn_window_elimination.cpp new file mode 100644 index 000000000..c5544d21f --- /dev/null +++ b/src/duckdb/src/optimizer/topn_window_elimination.cpp @@ -0,0 +1,979 @@ +#include "duckdb/optimizer/topn_window_elimination.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/optimizer/late_materialization_helper.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/struct_functions.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +namespace { + +idx_t GetGroupIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().group_index; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return op->children[0]->GetTableIndex()[0]; + } + return op->GetTableIndex()[0]; +} + +idx_t GetAggregateIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().aggregate_index; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return op->children[0]->GetTableIndex()[0]; + } + return op->GetTableIndex()[0]; +} + +LogicalType GetAggregateType(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNNEST: { + const auto &logical_unnest = op->Cast(); + const idx_t unnest_offset = logical_unnest.children[0]->types.size(); + return logical_unnest.types[unnest_offset]; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + const auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = logical_aggregate.groups.size(); + return logical_aggregate.types[aggregate_column_idx]; + } + default: { + throw InternalException("Unnest or aggregate expected to extract aggregate type."); + } + } +} + +vector ExtractReturnTypes(const vector> &exprs) { + vector types; + types.reserve(exprs.size()); + for (const auto &expr : exprs) { + types.push_back(expr->return_type); + } + return 
types; +} + +bool BindingsReferenceRowNumber(const vector &bindings, const LogicalWindow &window) { + for (const auto &binding : bindings) { + if (binding.table_index == window.window_index) { + return true; + } + } + return false; +} + +ColumnBinding GetRowNumberColumnBinding(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNNEST: { + const auto column_bindings = op->GetColumnBindings(); + const idx_t row_number_offset = op->children[0]->types.size() + 1; + D_ASSERT(op->types.size() == row_number_offset + 1); + return column_bindings[row_number_offset]; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + const auto &projection = op->Cast(); + return {projection.table_index, projection.types.size() - 1}; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + const auto &join = op->Cast(); + D_ASSERT(!join.right_projection_map.empty()); + const auto child_bindings = op->GetColumnBindings(); + return child_bindings[child_bindings.size() - 1]; + } + default: { + throw InternalException("Operator type not supported."); + } + } +} + +idx_t TraverseAndFindAggregateOffset(const unique_ptr &op) { + reference current_op = *op; + while (current_op.get().type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + D_ASSERT(!current_op.get().children.empty()); + current_op = *current_op.get().children[0]; + } + const auto &aggregate = current_op.get().Cast(); + return aggregate.groups.size(); +} + +} // namespace + +TopNWindowElimination::TopNWindowElimination(ClientContext &context_p, Optimizer &optimizer, + optional_ptr>> stats_p) + : context(context_p), optimizer(optimizer), stats(stats_p) { +} + +unique_ptr TopNWindowElimination::Optimize(unique_ptr op) { + auto &extension_manager = context.db->GetExtensionManager(); + if (!extension_manager.ExtensionIsLoaded("core_functions")) { + return op; + } + + ColumnBindingReplacer replacer; + op = OptimizeInternal(std::move(op), replacer); + if (!replacer.replacement_bindings.empty()) { + replacer.VisitOperator(*op); + } + return op; +} + +unique_ptr TopNWindowElimination::OptimizeInternal(unique_ptr op, + ColumnBindingReplacer &replacer) { + if (!CanOptimize(*op)) { + // Traverse through query plan to find grouped top-n pattern + if (op->children.size() > 1) { + // If an operator has multiple children, we do not want them to overwrite each other's stop operator. + // Thus, first update only the column binding in op, then set op as the new stop operator. 
+ for (auto &child : op->children) { + ColumnBindingReplacer r2; + child = OptimizeInternal(std::move(child), r2); + + if (!r2.replacement_bindings.empty()) { + r2.VisitOperator(*op); + replacer.replacement_bindings.insert(replacer.replacement_bindings.end(), + r2.replacement_bindings.begin(), + r2.replacement_bindings.end()); + replacer.stop_operator = op; + } + } + } else if (!op->children.empty()) { + op->children[0] = OptimizeInternal(std::move(op->children[0]), replacer); + } + + return op; + } + // We have made sure that this is an operator sequence of filter -> N optional projections -> window + auto &filter = op->Cast(); + reference child = *filter.children[0]; + + // Get bindings and types from filter to use in top-most operator later + const auto topmost_bindings = filter.GetColumnBindings(); + auto new_bindings = TraverseProjectionBindings(topmost_bindings, child); + + D_ASSERT(child.get().type == LogicalOperatorType::LOGICAL_WINDOW); + auto &window = child.get().Cast(); + const idx_t window_idx = window.window_index; + + // Map the input column offsets of the group columns to the output offset if there are projections on the group + // We use an ordered map here because we need to iterate over them in order later + map group_projection_idxs; + auto aggregate_payload = GenerateAggregatePayload(new_bindings, window, group_projection_idxs); + auto params = ExtractOptimizerParameters(window, filter, new_bindings, aggregate_payload); + + unique_ptr late_mat_lhs = nullptr; + if (params.payload_type == TopNPayloadType::STRUCT_PACK) { + // Try circumventing struct-packing with late materialization + late_mat_lhs = TryPrepareLateMaterialization(window, aggregate_payload); + if (late_mat_lhs) { + params.payload_type = TopNPayloadType::SINGLE_COLUMN; + } + } + + // Optimize window children + window.children[0] = Optimize(std::move(window.children[0])); + + op = CreateAggregateOperator(window, std::move(aggregate_payload), params); + op = TryCreateUnnestOperator(std::move(op), params); + op = CreateProjectionOperator(std::move(op), params, group_projection_idxs); + + D_ASSERT(op->type != LogicalOperatorType::LOGICAL_UNNEST); + + if (late_mat_lhs) { + op = ConstructJoin(std::move(late_mat_lhs), std::move(op), group_projection_idxs.size(), params); + } + + UpdateTopmostBindings(window_idx, op, group_projection_idxs, topmost_bindings, new_bindings, replacer); + replacer.stop_operator = op.get(); + + RemoveUnusedColumns unused_optimizer(optimizer.binder, optimizer.context, true); + unused_optimizer.VisitOperator(*op); + + return unique_ptr(std::move(op)); +} + +unique_ptr +TopNWindowElimination::CreateAggregateExpression(vector> aggregate_params, + const bool requires_arg, + const TopNWindowEliminationParameters ¶ms) const { + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + + // If the value column can be null, we must use the nulls_last function to follow null ordering semantics + const bool change_to_arg = !requires_arg && params.can_be_null && params.limit > 1; + if (change_to_arg) { + // Copy value as argument + aggregate_params.insert(aggregate_params.begin() + 1, aggregate_params[0]->Copy()); + } + + D_ASSERT(params.order_type == OrderType::ASCENDING || params.order_type == OrderType::DESCENDING); + string fun_name = requires_arg || change_to_arg ? "arg_" : ""; + fun_name += params.order_type == OrderType::ASCENDING ? "min" : "max"; + fun_name += params.can_be_null && (requires_arg || change_to_arg) ? 
"_nulls_last" : ""; + + auto &fun_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, fun_name); + const auto fun = fun_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(aggregate_params)); + return function_binder.BindAggregateFunction(fun, std::move(aggregate_params)); +} + +unique_ptr +TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const { + auto &window_expr = window.expressions[0]->Cast(); + D_ASSERT(window_expr.orders.size() == 1); + + vector> aggregate_params; + aggregate_params.reserve(3); + + const bool use_arg = !args.empty(); + if (args.size() == 1) { + aggregate_params.push_back(std::move(args[0])); + } else if (args.size() > 1) { + // For more than one arg, we must use struct pack + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + auto &struct_pack_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_pack"); + const auto struct_pack_fun = + struct_pack_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(args)); + auto struct_pack_expr = function_binder.BindScalarFunction(struct_pack_fun, std::move(args)); + aggregate_params.push_back(std::move(struct_pack_expr)); + } + + aggregate_params.push_back(std::move(window_expr.orders[0].expression)); + if (params.limit > 1) { + aggregate_params.push_back(std::move(make_uniq(Value::BIGINT(params.limit)))); + } + + auto aggregate_expr = CreateAggregateExpression(std::move(aggregate_params), use_arg, params); + + vector> select_list; + select_list.push_back(std::move(aggregate_expr)); + + auto aggregate = make_uniq(optimizer.binder.GenerateTableIndex(), + optimizer.binder.GenerateTableIndex(), std::move(select_list)); + aggregate->groupings_index = optimizer.binder.GenerateTableIndex(); + aggregate->groups = std::move(window_expr.partitions); + aggregate->children.push_back(std::move(window.children[0])); + aggregate->ResolveOperatorTypes(); + + // Add group statistics to allow for perfect hash aggregation if applicable + aggregate->group_stats.resize(aggregate->groups.size()); + for (idx_t i = 0; i < aggregate->groups.size(); i++) { + auto &group = aggregate->groups[i]; + if (group->type == ExpressionType::BOUND_COLUMN_REF) { + auto &column_ref = group->Cast(); + if (stats) { + auto group_stats = stats->find(column_ref.binding); + if (group_stats == stats->end()) { + continue; + } + aggregate->group_stats[i] = group_stats->second->ToUnique(); + } + } + } + + return unique_ptr(std::move(aggregate)); +} + +unique_ptr +TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const { + // Create unnest(generate_series(1, array_length(column_ref, 1))) function to generate row ids + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + + // array_length + auto &array_length_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "array_length"); + vector> array_length_exprs; + array_length_exprs.push_back(std::move(aggregate_column_ref)); + array_length_exprs.push_back(make_uniq(1)); + + const auto array_length_fun = array_length_entry.functions.GetFunctionByArguments( + context, {array_length_exprs[0]->return_type, array_length_exprs[1]->return_type}); + auto bound_array_length_fun = function_binder.BindScalarFunction(array_length_fun, std::move(array_length_exprs)); + + // generate_series + auto &generate_series_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "generate_series"); + + vector> generate_series_exprs; + 
generate_series_exprs.push_back(make_uniq(1)); + generate_series_exprs.push_back(std::move(bound_array_length_fun)); + + const auto generate_series_fun = generate_series_entry.functions.GetFunctionByArguments( + context, {generate_series_exprs[0]->return_type, generate_series_exprs[1]->return_type}); + auto bound_generate_series_fun = + function_binder.BindScalarFunction(generate_series_fun, std::move(generate_series_exprs)); + + // unnest + auto unnest_row_number_expr = make_uniq(LogicalType::BIGINT); + unnest_row_number_expr->alias = "row_number"; + unnest_row_number_expr->child = std::move(bound_generate_series_fun); + + return unique_ptr(std::move(unnest_row_number_expr)); +} + +unique_ptr +TopNWindowElimination::TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY); + + auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = logical_aggregate.groups.size(); + LogicalType aggregate_type = logical_aggregate.types[aggregate_column_idx]; + + if (params.limit <= 1) { + // LIMIT 1 -> we do not need to unnest + return std::move(op); + } + + // Create unnest expression for aggregate args + const auto aggregate_bindings = logical_aggregate.GetColumnBindings(); + auto aggregate_column_ref = + make_uniq(aggregate_type, aggregate_bindings[aggregate_column_idx]); + + vector> unnest_exprs; + + auto unnest_aggregate = make_uniq(ListType::GetChildType(aggregate_type)); + unnest_aggregate->child = aggregate_column_ref->Copy(); + unnest_exprs.push_back(std::move(unnest_aggregate)); + + if (params.include_row_number) { + // Create row number expression + unnest_exprs.push_back(CreateRowNumberGenerator(std::move(aggregate_column_ref))); + } + + auto unnest = make_uniq(optimizer.binder.GenerateTableIndex()); + unnest->expressions = std::move(unnest_exprs); + unnest->children.push_back(std::move(op)); + unnest->ResolveOperatorTypes(); + + return unique_ptr(std::move(unnest)); +} + +void TopNWindowElimination::AddStructExtractExprs( + vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const { + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + auto &struct_extract_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_extract"); + const auto struct_extract_fun = + struct_extract_entry.functions.GetFunctionByArguments(context, {struct_type, LogicalType::VARCHAR}); + + const auto &child_types = StructType::GetChildTypes(struct_type); + for (idx_t i = 0; i < child_types.size(); i++) { + const auto &alias = child_types[i].first; + + vector> fun_args(2); + fun_args[0] = aggregate_column_ref->Copy(); + fun_args[1] = make_uniq(alias); + + auto bound_function = function_binder.BindScalarFunction(struct_extract_fun, std::move(fun_args)); + bound_function->alias = alias; + exprs.push_back(std::move(bound_function)); + } +} + +unique_ptr +TopNWindowElimination::CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const { + const auto aggregate_type = GetAggregateType(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + const auto op_column_bindings = op->GetColumnBindings(); + + vector> proj_exprs; + // Only project necessary group columns + for (const auto &group_idx : group_idxs) { + proj_exprs.push_back( + make_uniq(op->types[group_idx.second], op_column_bindings[group_idx.second])); + } + + auto aggregate_column_ref = + 
make_uniq(aggregate_type, ColumnBinding(aggregate_table_idx, 0)); + + if (params.payload_type == TopNPayloadType::STRUCT_PACK) { + AddStructExtractExprs(proj_exprs, aggregate_type, aggregate_column_ref); + } else { + // No need for struct_unpack! Just reference the aggregate column + proj_exprs.push_back(std::move(aggregate_column_ref)); + } + + if (params.include_row_number) { + // If aggregate (i.e., limit 1): constant, if unnest: expect there to be a second column + if (op->type == LogicalOperatorType::LOGICAL_UNNEST) { + auto row_number_column_binding = GetRowNumberColumnBinding(op); + proj_exprs.push_back( + make_uniq("row_number", LogicalType::BIGINT, row_number_column_binding)); + } else { + proj_exprs.push_back(make_uniq(Value::BIGINT(1))); + } + } + + auto logical_projection = + make_uniq(optimizer.binder.GenerateTableIndex(), std::move(proj_exprs)); + logical_projection->children.push_back(std::move(op)); + logical_projection->ResolveOperatorTypes(); + + return unique_ptr(std::move(logical_projection)); +} + +bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { + if (op.type != LogicalOperatorType::LOGICAL_FILTER) { + return false; + } + + const auto &filter = op.Cast(); + if (filter.expressions.size() != 1) { + return false; + } + + if (filter.expressions[0]->type != ExpressionType::COMPARE_LESSTHANOREQUALTO) { + return false; + } + + auto &filter_comparison = filter.expressions[0]->Cast(); + if (filter_comparison.right->type != ExpressionType::VALUE_CONSTANT) { + return false; + } + auto &filter_value = filter_comparison.right->Cast(); + if (filter_value.value.type() != LogicalType::BIGINT) { + return false; + } + if (filter_value.value.GetValue() < 1) { + return false; + } + + if (filter_comparison.left->type != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + VisitExpression(&filter_comparison.left); + + reference child = *filter.children[0]; + while (child.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = child.get().Cast(); + if (column_references.size() != 1) { + column_references.clear(); + return false; + } + + const auto current_column_ref = column_references.begin()->first; + column_references.clear(); + D_ASSERT(current_column_ref.table_index == projection.table_index); + VisitExpression(&projection.expressions[current_column_ref.column_index]); + + child = *child.get().children[0]; + } + + if (column_references.size() != 1) { + column_references.clear(); + return false; + } + const auto filter_col_idx = column_references.begin()->first.table_index; + column_references.clear(); + + if (child.get().type != LogicalOperatorType::LOGICAL_WINDOW) { + return false; + } + const auto &window = child.get().Cast(); + if (window.window_index != filter_col_idx) { + return false; + } + if (window.expressions.size() != 1) { + for (idx_t i = 1; i < window.expressions.size(); ++i) { + if (!window.expressions[i]->Equals(*window.expressions[0])) { + return false; + } + } + } + if (window.expressions[0]->type != ExpressionType::WINDOW_ROW_NUMBER) { + return false; + } + auto &window_expr = window.expressions[0]->Cast(); + + if (window_expr.orders.size() != 1) { + return false; + } + if (window_expr.orders[0].type != OrderType::DESCENDING && window_expr.orders[0].type != OrderType::ASCENDING) { + return false; + } + if (window_expr.orders[0].null_order != OrderByNullType::NULLS_LAST) { + return false; + } + + // We have found a grouped top-n window construct! 
+ return true; +} + +vector> TopNWindowElimination::GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, + map &group_idxs) { + vector> aggregate_args; + aggregate_args.reserve(bindings.size()); + + window.children[0]->ResolveOperatorTypes(); + const auto &window_child_types = window.children[0]->types; + const auto window_child_bindings = window.children[0]->GetColumnBindings(); + auto &window_expr = window.expressions[0]->Cast(); + + // Remember order of group columns to recreate that order in new bindings later + column_binding_map_t group_bindings; + for (idx_t i = 0; i < window_expr.partitions.size(); i++) { + auto &expr = window_expr.partitions[i]; + VisitExpression(&expr); + group_bindings[column_references.begin()->first] = i; + column_references.clear(); + } + + for (idx_t i = 0; i < bindings.size(); i++) { + const auto &binding = bindings[i]; + const auto group_binding = group_bindings.find(binding); + if (group_binding != group_bindings.end()) { + group_idxs[i] = group_binding->second; + continue; + } + if (binding.table_index == window.window_index) { + continue; + } + + auto column_id = binding.ToString(); + if (window.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // The column index points to the correct column binding + aggregate_args.push_back( + make_uniq(column_id, window_child_types[binding.column_index], binding)); + } else { + // The child operator could have multiple or no table indexes. Therefore, we must find the right type first + const auto child_column_idx = + static_cast(std::find(window_child_bindings.begin(), window_child_bindings.end(), binding) - + window_child_bindings.begin()); + aggregate_args.push_back( + make_uniq(column_id, window_child_types[child_column_idx], binding)); + } + } + + if (aggregate_args.size() == 1) { + // If we only project the aggregate value itself, we do not need it as an arg + VisitExpression(&window_expr.orders[0].expression); + const auto aggregate_value_binding = column_references.begin()->first; + column_references.clear(); + + if (window_expr.orders[0].expression->type == ExpressionType::BOUND_COLUMN_REF && + aggregate_args[0]->Cast().binding == aggregate_value_binding) { + return {}; + } + } + + return aggregate_args; +} + +vector TopNWindowElimination::TraverseProjectionBindings(const std::vector &old_bindings, + reference &op) { + auto new_bindings = old_bindings; + + // Traverse child projections to retrieve projections on window output + while (op.get().type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.get().Cast(); + + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &new_binding = new_bindings[i]; + D_ASSERT(new_binding.table_index == projection.table_index); + VisitExpression(&projection.expressions[new_binding.column_index]); + new_binding = column_references.begin()->first; + column_references.clear(); + } + op = *op.get().children[0]; + } + + return new_bindings; +} + +void TopNWindowElimination::UpdateTopmostBindings(const idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, + ColumnBindingReplacer &replacer) { + // The top-most operator's column order is [group][aggregate args][row number]. Now, set the new resulting bindings. 
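+ // Bindings that still reference the window's row_number output are collected separately and remapped last.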
+ D_ASSERT(topmost_bindings.size() == new_bindings.size()); + replacer.replacement_bindings.reserve(new_bindings.size()); + set row_id_binding_idxs; + + const idx_t group_table_idx = GetGroupIdx(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + + // Project the group columns + idx_t current_column_idx = 0; + for (auto group_idx : group_idxs) { + const idx_t group_referencing_idx = group_idx.first; + new_bindings[group_referencing_idx].table_index = group_table_idx; + new_bindings[group_referencing_idx].column_index = group_idx.second; + replacer.replacement_bindings.emplace_back(topmost_bindings[group_referencing_idx], + new_bindings[group_referencing_idx]); + current_column_idx++; + } + + if (group_table_idx != aggregate_table_idx) { + // If the topmost operator is an aggregate, the table indexes are different, and we start back from 0 + current_column_idx = 0; + } + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + // We do not have an aggregate index, so we need to set an offset to hit the correct columns + current_column_idx = TraverseAndFindAggregateOffset(op->children[1]); + } + + // Project the args/value + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &binding = new_bindings[i]; + if (group_idxs.find(i) != group_idxs.end()) { + continue; + } + if (binding.table_index == window_idx) { + row_id_binding_idxs.insert(i); + continue; + } + binding.column_index = current_column_idx++; + binding.table_index = aggregate_table_idx; + replacer.replacement_bindings.emplace_back(topmost_bindings[i], binding); + } + + // Project the row number + for (const auto row_id_binding_idx : row_id_binding_idxs) { + // Let all projections on row id point to the last output column + auto &binding = new_bindings[row_id_binding_idx]; + binding = GetRowNumberColumnBinding(op); + replacer.replacement_bindings.emplace_back(topmost_bindings[row_id_binding_idx], binding); + } +} + +TopNWindowEliminationParameters +TopNWindowElimination::ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + vector> &aggregate_payload) { + TopNWindowEliminationParameters params; + + auto &limit_expr = filter.expressions[0]->Cast().right; + params.limit = limit_expr->Cast().value.GetValue(); + params.include_row_number = BindingsReferenceRowNumber(bindings, window); + params.payload_type = aggregate_payload.size() > 1 ? 
TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; + auto &window_expr = window.expressions[0]->Cast(); + params.order_type = window_expr.orders[0].type; + + VisitExpression(&window_expr.orders[0].expression); + if (params.payload_type == TopNPayloadType::SINGLE_COLUMN && !aggregate_payload.empty()) { + VisitExpression(&aggregate_payload[0]); + } + for (const auto &column_ref : column_references) { + const auto &column_stats = stats->find(column_ref.first); + if (column_stats == stats->end() || column_stats->second->CanHaveNull()) { + params.can_be_null = true; + } + } + column_references.clear(); + + return params; +} + +bool TopNWindowElimination::CanUseLateMaterialization(const LogicalWindow &window, vector> &args, + vector &lhs_projections, + vector> &stack) { + auto &window_expr = window.expressions[0]->Cast(); + vector projections(window_expr.partitions.size() + args.size()); + + // Build a projection list for an LHS table scan to recreate the column order of an aggregate with struct packing + for (idx_t i = 0; i < window_expr.partitions.size(); i++) { + auto &partition = window_expr.partitions[i]; + VisitExpression(&partition); + projections[i] = column_references.begin()->first; + column_references.clear(); + } + for (idx_t i = 0; i < args.size(); i++) { + auto &arg = args[i]; + VisitExpression(&arg); + projections[window_expr.partitions.size() + i] = column_references.begin()->first; + column_references.clear(); + } + + reference op = *window.children[0]; + + // Traverse projections to a single table scan + while (!op.get().children.empty()) { + stack.push_back(op); + switch (op.get().type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &projection = op.get().Cast(); + for (idx_t i = 0; i < projections.size(); i++) { + D_ASSERT(projection.table_index == projections[i].table_index); + const idx_t projection_idx = projections[i].column_index; + VisitExpression(&projection.expressions[projection_idx]); + projections[i] = column_references.begin()->first; + column_references.clear(); + } + op = *op.get().children[0]; + break; + } + case LogicalOperatorType::LOGICAL_FILTER: { + op = *op.get().children[0]; + break; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op.get().Cast(); + if (join.join_type != JoinType::INNER && join.join_type != JoinType::SEMI && + join.join_type != JoinType::ANTI) { + return false; + } + + // If there is a join, we only allow late materialization if the projected output stems from a single table. + // However, we allow replacing references to join columns as they are equal to the other side by condition. 
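+ // Record both directions of every equality condition so a binding from either side can be rewritten to its counterpart.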
+ column_binding_map_t replaceable_bindings; + for (auto &condition : join.conditions) { + if (condition.comparison != ExpressionType::COMPARE_EQUAL) { + return false; + } + VisitExpression(&condition.left); + auto left_binding = column_references.begin()->first; + column_references.clear(); + VisitExpression(&condition.right); + auto right_binding = column_references.begin()->first; + column_references.clear(); + + replaceable_bindings[left_binding] = right_binding; + replaceable_bindings[right_binding] = left_binding; + } + + auto left_column_bindings = join.children[0]->GetColumnBindings(); + auto right_column_bindings = join.children[1]->GetColumnBindings(); + auto lidxs = join.children[0]->GetTableIndex(); + auto ridxs = join.children[1]->GetTableIndex(); + if (lidxs.size() != 1 || ridxs.size() != 1) { + return false; + } + auto left_idx = lidxs[0]; + auto right_idx = ridxs[0]; + + bool all_left_replaceable = true; + bool all_right_replaceable = true; + for (idx_t i = 0; i < projections.size(); i++) { + const auto &projection = projections[i]; + auto &column_binding = projection.table_index == left_idx + ? left_column_bindings[projection.column_index] + : right_column_bindings[projection.column_index]; + if (replaceable_bindings.find(column_binding) == replaceable_bindings.end()) { + if (column_binding.table_index == left_idx) { + all_left_replaceable = false; + } else { + all_right_replaceable = false; + } + } + } + + if (!all_left_replaceable && !all_right_replaceable) { + // We cannot use late materialization by scanning a single table. + return false; + } + + idx_t replace_table_idx = all_right_replaceable ? right_idx : left_idx; + for (idx_t i = 0; i < projections.size(); i++) { + const auto projection_idx = projections[i]; + auto &column_binding = projection_idx.table_index == left_idx + ? 
left_column_bindings[projection_idx.column_index] + : right_column_bindings[projection_idx.column_index]; + if (column_binding.table_index == replace_table_idx) { + projections[i] = replaceable_bindings[column_binding]; + } + } + + if (all_right_replaceable) { + op = *op.get().children[0]; + } else { + op = *op.get().children[1]; + } + break; + } + default: { + return false; + } + } + } + stack.push_back(op); + + D_ASSERT(op.get().type == LogicalOperatorType::LOGICAL_GET); + auto &logical_get = op.get().Cast(); + if (!logical_get.function.late_materialization || !logical_get.function.get_row_id_columns) { + return false; + } + + const auto rowid_column_idxs = logical_get.function.get_row_id_columns(context, logical_get.bind_data.get()); + if (rowid_column_idxs.size() > 1) { + // TODO: support multi-column rowids for parquet + return false; + } + for (const auto &col_idx : rowid_column_idxs) { + auto entry = logical_get.virtual_columns.find(col_idx); + if (entry == logical_get.virtual_columns.end()) { + return false; + } + } + // Check if we need the projection map + for (idx_t i = 0; i < projections.size(); i++) { + if (projections[i].column_index != i) { + for (auto &proj : projections) { + lhs_projections.push_back(proj.column_index); + } + break; + } + } + return true; +} + +unique_ptr TopNWindowElimination::TryPrepareLateMaterialization(const LogicalWindow &window, + vector> &args) { + vector lhs_projections; + vector> stack; + bool use_late_materialization = CanUseLateMaterialization(window, args, lhs_projections, stack); + if (!use_late_materialization) { + return nullptr; + } + + D_ASSERT(stack.back().get().type == LogicalOperatorType::LOGICAL_GET); + auto &rhs_get = stack.back().get().Cast(); + auto lhs = ConstructLHS(rhs_get, lhs_projections); + + const auto rhs_rowid_column_idxs = rhs_get.function.get_row_id_columns(context, rhs_get.bind_data.get()); + vector rhs_rowid_columns; + for (const auto &col_idx : rhs_rowid_column_idxs) { + rhs_rowid_columns.push_back(rhs_get.virtual_columns[col_idx]); + } + const auto rhs_rowid_idxs = + LateMaterializationHelper::GetOrInsertRowIds(rhs_get, rhs_rowid_column_idxs, rhs_rowid_columns); + + // Add rowid column to the operators on the right-hand side + idx_t last_table_idx = rhs_get.table_index; + idx_t last_rowid_offset = rhs_rowid_idxs[0]; + + // Add rowid projections to the query tree on the right-hand side + for (auto stack_it = std::next(stack.rbegin()); stack_it != stack.rend(); ++stack_it) { + auto &op = stack_it->get(); + + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &rowid_column = rhs_rowid_columns[0]; + op.expressions.push_back(make_uniq( + rowid_column.name, rowid_column.type, ColumnBinding {last_table_idx, last_rowid_offset})); + last_table_idx = op.GetTableIndex()[0]; + last_rowid_offset = op.expressions.size() - 1; + break; + } + case LogicalOperatorType::LOGICAL_FILTER: { + if (op.HasProjectionMap()) { + auto &filter = op.Cast(); + filter.projection_map.push_back(last_rowid_offset); + } + break; + } + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + if (op.HasProjectionMap()) { + auto &join = op.Cast(); + auto &op_child = std::prev(stack_it)->get(); + if (&op_child == &*join.children[0]) { + join.left_projection_map.push_back(last_rowid_offset); + } else { + join.right_projection_map.push_back(last_rowid_offset); + } + } + break; + } + default: + throw InternalException("Unsupported operator in late materialization right-hand side."); + } + } + + // Change args to project rowid + 
args.clear(); + args.push_back(make_uniq(rhs_rowid_columns[0].name, rhs_rowid_columns[0].type, + ColumnBinding {last_table_idx, last_rowid_offset})); + + return lhs; +} + +unique_ptr TopNWindowElimination::ConstructLHS(LogicalGet &rhs, vector &projections) const { + auto lhs_get = LateMaterializationHelper::CreateLHSGet(rhs, optimizer.binder); + const auto lhs_rowid_column_idxs = lhs_get->function.get_row_id_columns(context, lhs_get->bind_data.get()); + vector lhs_rowid_columns; + for (const auto &col_idx : lhs_rowid_column_idxs) { + lhs_rowid_columns.push_back(rhs.virtual_columns[col_idx]); + } + + const auto lhs_rowid_idxs = + LateMaterializationHelper::GetOrInsertRowIds(*lhs_get, lhs_rowid_column_idxs, lhs_rowid_columns); + + if (!projections.empty()) { + for (auto rowid_idx : lhs_rowid_idxs) { + projections.push_back(rowid_idx); + } + lhs_get->ResolveOperatorTypes(); + + vector> projs; + projs.reserve(projections.size()); + for (auto projection_id : projections) { + projs.push_back(make_uniq(lhs_get->types[projection_id], + ColumnBinding {lhs_get->table_index, projection_id})); + } + auto projection = make_uniq(optimizer.binder.GenerateTableIndex(), std::move(projs)); + projection->children.push_back(std::move(lhs_get)); + return unique_ptr(std::move(projection)); + } + return unique_ptr(std::move(lhs_get)); +} + +unique_ptr TopNWindowElimination::ConstructJoin(unique_ptr lhs, + unique_ptr rhs, + const idx_t aggregate_offset, + const TopNWindowEliminationParameters ¶ms) { + auto join = make_uniq(JoinType::SEMI); + + JoinCondition condition; + condition.comparison = ExpressionType::COMPARE_EQUAL; + + lhs->ResolveOperatorTypes(); + const auto lhs_rowid_idx = lhs->types.size() - 1; + const auto rhs_rowid_idx = rhs->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY ? 
0 : aggregate_offset; + + condition.left = make_uniq("rowid", lhs->types[lhs_rowid_idx], + ColumnBinding {lhs->GetTableIndex()[0], lhs_rowid_idx}); + condition.right = make_uniq("rowid", rhs->types[aggregate_offset], + ColumnBinding {GetAggregateIdx(rhs), rhs_rowid_idx}); + + join->conditions.push_back(std::move(condition)); + if (params.include_row_number) { + // Add row_number to join result + join->join_type = JoinType::INNER; + join->right_projection_map.push_back(rhs->types.size() - 1); + } + + join->children.push_back(std::move(lhs)); + join->children.push_back(std::move(rhs)); + + return unique_ptr(std::move(join)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/unnest_rewriter.cpp b/src/duckdb/src/optimizer/unnest_rewriter.cpp index 4c4207e2a..353c1401e 100644 --- a/src/duckdb/src/optimizer/unnest_rewriter.cpp +++ b/src/duckdb/src/optimizer/unnest_rewriter.cpp @@ -1,13 +1,19 @@ #include "duckdb/optimizer/unnest_rewriter.hpp" -#include "duckdb/common/pair.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/column_binding.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/logical_operator.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_delim_get.hpp" +#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_unnest.hpp" -#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/planner/expression_binder.hpp" namespace duckdb { @@ -33,14 +39,12 @@ void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expressi } unique_ptr UnnestRewriter::Optimize(unique_ptr op) { - UnnestRewriterPlanUpdater updater; vector>> candidates; - FindCandidates(op, candidates); + FindCandidates(op, op, candidates); // rewrite the plan and update the bindings for (auto &candidate : candidates) { - // rearrange the logical operators if (RewriteCandidate(candidate)) { updater.overwritten_tbl_idx = overwritten_tbl_idx; @@ -57,11 +61,11 @@ unique_ptr UnnestRewriter::Optimize(unique_ptr return op; } -void UnnestRewriter::FindCandidates(unique_ptr &op, +void UnnestRewriter::FindCandidates(unique_ptr &root, unique_ptr &op, vector>> &candidates) { // search children before adding, so that we add candidates bottom-up for (auto &child : op->children) { - FindCandidates(child, candidates); + FindCandidates(root, child, candidates); } // search for operator that has a LOGICAL_DELIM_JOIN as its child @@ -99,14 +103,69 @@ void UnnestRewriter::FindCandidates(unique_ptr &op, curr_op = &curr_op->get()->children[0]; } + // pattern1: delim_get -> unnest-> projection if (curr_op->get()->type == LogicalOperatorType::LOGICAL_UNNEST && curr_op->get()->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { candidates.push_back(op); + return; + } + + curr_op = &delim_join.children[other_idx]; + if (curr_op->get()->type == LogicalOperatorType::LOGICAL_GET) { + auto &get = curr_op->get()->Cast(); + if (!ExpressionBinder::IsUnnestFunction(get.function.name)) { + return; + } + // pattern2: delim_get -> projection -> table_in_out(unnest) + auto &unnest_get_ref = curr_op->get()->Cast(); + if (unnest_get_ref.ordinality_idx.IsValid()) { + // we also unnest delim_index so 
cannot rewrite it + return; + } + curr_op = &curr_op->get()->children[0]; + + // find pattern2 and convert to pattern1 + if (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION && + curr_op->get()->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { + auto unnest_get = std::move(delim_join.children[other_idx]); + unnest_get->ResolveOperatorTypes(); + ColumnBindingReplacer replacer; + auto unnest_get_column = unnest_get->GetColumnBindings(); + auto &proj = curr_op->get()->Cast(); + auto delim_get = std::move(proj.children[0]); + auto unnest = make_uniq(unnest_get->GetTableIndex()[0]); + unnest->children.push_back(std::move(delim_get)); + delim_join.children[other_idx] = std::move(*curr_op); + for (idx_t i = 0; i < unnest_get_column.size(); i++) { + auto &col_bind = unnest_get_column[i]; + D_ASSERT(col_bind.table_index == unnest_get->GetTableIndex()[0] || + col_bind.table_index == proj.table_index); + if (col_bind.table_index == unnest_get->GetTableIndex()[0]) { + D_ASSERT(proj.expressions[col_bind.column_index]->GetExpressionClass() == + ExpressionClass::BOUND_COLUMN_REF); + auto &bind_col = proj.expressions[col_bind.column_index]->Cast(); + auto unnest_expr = make_uniq(unnest_get->types[i]); + unnest_expr->child = proj.expressions[col_bind.column_index]->Copy(); + bind_col.binding = ColumnBinding(unnest->GetTableIndex()[0], bind_col.binding.column_index); + unnest->expressions.push_back(std::move(unnest_expr)); + auto new_column_ref = ColumnBinding(bind_col.binding.table_index, unnest->expressions.size() - 1); + auto unnest_ref = make_uniq(bind_col.alias, unnest_get->types[i], + new_column_ref, bind_col.depth); + proj.expressions[col_bind.column_index] = std::move(unnest_ref); + proj.types[col_bind.column_index] = unnest_get->types[i]; + replacer.replacement_bindings.push_back(ReplacementBinding( + col_bind, ColumnBinding(proj.table_index, col_bind.column_index), unnest_get->types[i])); + } + } + proj.children[0] = std::move(unnest); + replacer.stop_operator = proj; + replacer.VisitOperator(*root); + candidates.push_back(op); + } } } bool UnnestRewriter::RewriteCandidate(unique_ptr &candidate) { - auto &topmost_op = *candidate; if (topmost_op.type != LogicalOperatorType::LOGICAL_PROJECTION && topmost_op.type != LogicalOperatorType::LOGICAL_WINDOW && @@ -160,14 +219,12 @@ bool UnnestRewriter::RewriteCandidate(unique_ptr &candidate) { void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique_ptr &candidate, UnnestRewriterPlanUpdater &updater) { - auto &topmost_op = *candidate; idx_t shift = lhs_bindings.size(); vector *> path_to_unnest; auto curr_op = &topmost_op.children[0]; while (curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION) { - path_to_unnest.push_back(curr_op); D_ASSERT(curr_op->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); auto &proj = curr_op->get()->Cast(); @@ -222,7 +279,6 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique // add the LHS expressions to each LOGICAL_PROJECTION for (idx_t i = path_to_unnest.size(); i > 0; i--) { - D_ASSERT(path_to_unnest[i - 1]->get()->type == LogicalOperatorType::LOGICAL_PROJECTION); auto &proj = path_to_unnest[i - 1]->get()->Cast(); @@ -254,7 +310,6 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique void UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &updater, unique_ptr &candidate) { - auto &topmost_op = *candidate; // traverse LOGICAL_PROJECTION(s) @@ -296,7 +351,6 @@ void 
UnnestRewriter::UpdateBoundUnnestBindings(UnnestRewriterPlanUpdater &update } void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { - D_ASSERT(op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); auto &delim_join = op.Cast<LogicalComparisonJoin>(); for (idx_t i = 0; i < delim_join.duplicate_eliminated_columns.size(); i++) { @@ -308,7 +362,6 @@ void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { } void UnnestRewriter::GetLHSExpressions(LogicalOperator &op) { - op.ResolveOperatorTypes(); auto col_bindings = op.GetColumnBindings(); D_ASSERT(op.types.size() == col_bindings.size()); diff --git a/src/duckdb/src/parallel/async_result.cpp b/src/duckdb/src/parallel/async_result.cpp new file mode 100644 index 000000000..a32086b84 --- /dev/null +++ b/src/duckdb/src/parallel/async_result.cpp @@ -0,0 +1,192 @@ +#include "duckdb/parallel/executor_task.hpp" +#include "duckdb/parallel/async_result.hpp" +#include "duckdb/parallel/interrupt.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/physical_table_scan_enum.hpp" + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE +#include "duckdb/parallel/sleep_async_task.hpp" +#endif + +namespace duckdb { + +struct Counter { + explicit Counter(idx_t size) : counter(size) { + } + bool IterateAndCheckCounter() { + D_ASSERT(counter.load() > 0); + idx_t post_decrement = --counter; + return (post_decrement == 0); + } + +private: + atomic<idx_t> counter; +}; + +class AsyncExecutionTask : public ExecutorTask { +public: + AsyncExecutionTask(Executor &executor, unique_ptr &&async_task, InterruptState &interrupt_state, + shared_ptr<Counter> counter) + : ExecutorTask(executor, nullptr), async_task(std::move(async_task)), interrupt_state(interrupt_state), + counter(std::move(counter)) { + } + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + async_task->Execute(); + if (counter->IterateAndCheckCounter()) { + interrupt_state.Callback(); + } + return TaskExecutionResult::TASK_FINISHED; + } + + string TaskType() const override { + return "AsyncTask"; + } + +private: + unique_ptr async_task; + InterruptState interrupt_state; + shared_ptr<Counter> counter; +}; + +AsyncResult::AsyncResult(SourceResultType t) : AsyncResult(GetAsyncResultType(t)) { +} + +AsyncResult::AsyncResult(AsyncResultType t) : result_type(t) { + if (result_type == AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult constructed with a BLOCKED state, provide AsyncTasks instead"); + } +} + +AsyncResult::AsyncResult(vector> &&tasks) + : result_type(AsyncResultType::BLOCKED), async_tasks(std::move(tasks)) { + if (async_tasks.empty()) { + throw InternalException("AsyncResult constructed from empty vector of tasks"); + } +} + +AsyncResult &AsyncResult::operator=(duckdb::SourceResultType t) { + return operator=(AsyncResult(t)); +} + +AsyncResult &AsyncResult::operator=(duckdb::AsyncResultType t) { + return operator=(AsyncResult(t)); +} + +AsyncResult &AsyncResult::operator=(AsyncResult &&other) noexcept { + result_type = other.result_type; + async_tasks = std::move(other.async_tasks); + return *this; +} + +void AsyncResult::ScheduleTasks(InterruptState &interrupt_state, Executor &executor) { + if (result_type != AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult::ScheduleTasks called on non BLOCKED AsyncResult"); + } + + if (async_tasks.empty()) { + throw InternalException("AsyncResult::ScheduleTasks called with no available tasks"); + } + + shared_ptr<Counter> counter = make_shared_ptr<Counter>(async_tasks.size()); + + for (auto &async_task : async_tasks) { + auto 
task = make_uniq(executor, std::move(async_task), interrupt_state, counter); + TaskScheduler::GetScheduler(executor.context).ScheduleTask(executor.GetToken(), std::move(task)); + } +} + +void AsyncResult::ExecuteTasksSynchronously() { + if (result_type != AsyncResultType::BLOCKED) { + throw InternalException("AsyncResult::ExecuteTasksSynchronously called on non BLOCKED AsyncResult"); + } + + if (async_tasks.empty()) { + throw InternalException("AsyncResult::ExecuteTasksSynchronously called with no available tasks"); + } + + for (auto &async_task : async_tasks) { + async_task->Execute(); + } + + async_tasks.clear(); + + result_type = AsyncResultType::HAVE_MORE_OUTPUT; +} + +AsyncResultType AsyncResult::GetAsyncResultType(SourceResultType s) { + switch (s) { + case SourceResultType::HAVE_MORE_OUTPUT: + return AsyncResultType::HAVE_MORE_OUTPUT; + case SourceResultType::FINISHED: + return AsyncResultType::FINISHED; + case SourceResultType::BLOCKED: + return AsyncResultType::BLOCKED; + } + throw InternalException("GetAsyncResultType has an unexpected input"); +} + +bool AsyncResult::HasTasks() const { + D_ASSERT(result_type != AsyncResultType::INVALID); + if (async_tasks.empty()) { + D_ASSERT(result_type != AsyncResultType::BLOCKED); + return false; + } else { + D_ASSERT(result_type == AsyncResultType::BLOCKED); + return true; + } +} +AsyncResultType AsyncResult::GetResultType() const { + D_ASSERT(result_type != AsyncResultType::INVALID); + if (async_tasks.empty()) { + D_ASSERT(result_type != AsyncResultType::BLOCKED); + } else { + D_ASSERT(result_type == AsyncResultType::BLOCKED); + } + return result_type; +} +vector> &&AsyncResult::ExtractAsyncTasks() { + D_ASSERT(result_type != AsyncResultType::INVALID); + result_type = AsyncResultType::INVALID; + return std::move(async_tasks); +} + +#ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE +vector> AsyncResult::GenerateTestTasks() { + vector> tasks; + auto random_number = rand() % 16; + switch (random_number) { + case 0: + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); + tasks.push_back(make_uniq(rand() % 32)); +#ifndef AVOID_DUCKDB_DEBUG_ASYNC_THROW + case 1: + tasks.push_back(make_uniq(rand() % 32)); +#endif + default: + break; + } + return tasks; +} +#endif + +AsyncResultsExecutionMode +AsyncResult::ConvertToAsyncResultExecutionMode(const PhysicalTableScanExecutionStrategy &execution_mode) { + switch (execution_mode) { + case PhysicalTableScanExecutionStrategy::DEFAULT: + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR: + case PhysicalTableScanExecutionStrategy::TASK_EXECUTOR_BUT_FORCE_SYNC_CHECKS: + return AsyncResultsExecutionMode::TASK_EXECUTOR; + case PhysicalTableScanExecutionStrategy::SYNCHRONOUS: + return AsyncResultsExecutionMode::SYNCHRONOUS; + } + throw InternalException("ConvertToAsyncResultExecutionMode passed an unexpected execution_mode"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parallel/executor.cpp b/src/duckdb/src/parallel/executor.cpp index d79fa1816..9a9cf4703 100644 --- a/src/duckdb/src/parallel/executor.cpp +++ b/src/duckdb/src/parallel/executor.cpp @@ -379,7 +379,6 @@ void Executor::Initialize(PhysicalOperator &plan) { } void Executor::InitializeInternal(PhysicalOperator &plan) { - auto &scheduler = TaskScheduler::GetScheduler(context); { lock_guard 
elock(executor_lock); @@ -423,7 +422,6 @@ void Executor::InitializeInternal(PhysicalOperator &plan) { void Executor::CancelTasks() { task.reset(); - { lock_guard elock(executor_lock); // mark the query as cancelled so tasks will early-out @@ -463,17 +461,23 @@ void Executor::SignalTaskRescheduled(lock_guard &) { void Executor::WaitForTask() { #ifndef DUCKDB_NO_THREADS - static constexpr std::chrono::milliseconds WAIT_TIME_MS = std::chrono::milliseconds(WAIT_TIME); + static constexpr std::chrono::microseconds WAIT_TIME_MS = std::chrono::microseconds(WAIT_TIME * 1000); + auto begin = std::chrono::high_resolution_clock::now(); std::unique_lock l(executor_lock); + auto end = std::chrono::high_resolution_clock::now(); + auto dur = end - begin; + auto ms = NumericCast(std::chrono::duration_cast(dur).count()); if (to_be_rescheduled_tasks.empty()) { + blocked_thread_time += ms; return; } if (ResultCollectorIsBlocked()) { // If the result collector is blocked, it won't get unblocked until the connection calls Fetch + blocked_thread_time += ms; return; } - blocked_thread_time++; + blocked_thread_time += ms + WAIT_TIME_MS.count(); task_reschedule.wait_for(l, WAIT_TIME_MS); #endif } @@ -578,6 +582,12 @@ PendingExecutionResult Executor::ExecuteTask(bool dry_run) { } else if (result == TaskExecutionResult::TASK_FINISHED) { // if the task is finished, clean it up task.reset(); + } else if (result == TaskExecutionResult::TASK_ERROR) { + if (!HasError()) { + // This is very much unexpected, TASK_ERROR means this executor should have an Error + throw InternalException("A task executed within Executor::ExecuteTask, from own producer, returned " + "TASK_ERROR without setting error on the Executor"); + } } } if (!HasError()) { @@ -672,13 +682,12 @@ void Executor::ThrowException() { } void Executor::Flush(ThreadContext &thread_context) { - static constexpr std::chrono::milliseconds WAIT_TIME_MS = std::chrono::milliseconds(WAIT_TIME); auto global_profiler = profiler; if (global_profiler) { global_profiler->Flush(thread_context.profiler); auto blocked_time = blocked_thread_time.load(); - global_profiler->SetInfo(double(blocked_time * WAIT_TIME_MS.count()) / 1000); + global_profiler->SetBlockedTime(double(blocked_time) / 1000.0 / 1000.0); } } diff --git a/src/duckdb/src/parallel/pipeline.cpp b/src/duckdb/src/parallel/pipeline.cpp index bc511539a..1e5a20b4a 100644 --- a/src/duckdb/src/parallel/pipeline.cpp +++ b/src/duckdb/src/parallel/pipeline.cpp @@ -106,12 +106,16 @@ bool Pipeline::ScheduleParallel(shared_ptr &event) { if (!source->ParallelSource()) { return false; } + auto max_threads = source_state->MaxThreads(); + for (auto &op_ref : operators) { auto &op = op_ref.get(); if (!op.ParallelOperator()) { return false; } + max_threads = MinValue(max_threads, op.op_state->MaxThreads(max_threads)); } + auto partition_info = sink->RequiredPartitionInfo(); if (partition_info.batch_index) { if (!source->SupportsPartitioning(OperatorPartitionInfo::BatchIndex())) { @@ -119,7 +123,7 @@ bool Pipeline::ScheduleParallel(shared_ptr &event) { "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); } } - auto max_threads = source_state->MaxThreads(); + auto &scheduler = TaskScheduler::GetScheduler(executor.context); auto active_threads = NumericCast(scheduler.NumberOfThreads()); if (max_threads > active_threads) { diff --git a/src/duckdb/src/parallel/pipeline_executor.cpp b/src/duckdb/src/parallel/pipeline_executor.cpp index 9db69ac99..274dc8b99 100644 --- 
a/src/duckdb/src/parallel/pipeline_executor.cpp +++ b/src/duckdb/src/parallel/pipeline_executor.cpp @@ -123,12 +123,15 @@ bool PipelineExecutor::TryFlushCachingOperators(ExecutionBudget &chunk_budget) { return true; } -SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk) { +SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk, const bool have_more_output) { D_ASSERT(required_partition_info.AnyRequired()); auto max_batch_index = pipeline.base_batch_index + PipelineBuildState::BATCH_INCREMENT - 1; // by default set it to the maximum valid batch index value for the current pipeline + auto &partition_info = local_sink_state->partition_info; OperatorPartitionData next_data(max_batch_index); - if (source_chunk.size() > 0) { + if ((source_chunk.size() > 0)) { + D_ASSERT(local_source_state); + D_ASSERT(pipeline.source_state); // if we retrieved data - initialize the next batch index auto partition_data = pipeline.source->GetPartitionData(context, source_chunk, *pipeline.source_state, *local_source_state, required_partition_info); @@ -140,8 +143,9 @@ SinkNextBatchType PipelineExecutor::NextBatch(DataChunk &source_chunk) { throw InternalException("Pipeline batch index - invalid batch index %llu returned by source operator", batch_index); } + } else if (have_more_output) { + next_data.batch_index = partition_info.batch_index.GetIndex(); } - auto &partition_info = local_sink_state->partition_info; if (next_data.batch_index == partition_info.batch_index.GetIndex()) { // no changes, return return SinkNextBatchType::READY; @@ -193,7 +197,7 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { } OperatorResultType result; - if (exhausted_source && done_flushing && !remaining_sink_chunk && !next_batch_blocked && + if (exhausted_pipeline && done_flushing && !remaining_sink_chunk && !next_batch_blocked && in_process_operators.empty()) { break; } else if (remaining_sink_chunk) { @@ -206,8 +210,8 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { // the operators have to be called with the same input chunk to produce the rest of the output D_ASSERT(source_chunk.size() > 0); result = ExecutePushInternal(source_chunk, chunk_budget); - } else if (exhausted_source && !next_batch_blocked && !done_flushing) { - // The source was exhausted, try flushing all operators + } else if (exhausted_pipeline && !next_batch_blocked && !done_flushing) { + // The pipeline was exhausted, try flushing all operators auto flush_completed = TryFlushCachingOperators(chunk_budget); if (flush_completed) { done_flushing = true; @@ -220,8 +224,8 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { return PipelineExecuteResult::NOT_FINISHED; } } - } else if (!exhausted_source || next_batch_blocked) { - SourceResultType source_result; + } else if (!exhausted_pipeline || next_batch_blocked) { + SourceResultType source_result = SourceResultType::BLOCKED; if (!next_batch_blocked) { // "Regular" path: fetch a chunk from the source and push it through the pipeline source_chunk.Reset(); @@ -230,20 +234,19 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { return PipelineExecuteResult::INTERRUPTED; } if (source_result == SourceResultType::FINISHED) { - exhausted_source = true; + exhausted_pipeline = true; } } if (required_partition_info.AnyRequired()) { - auto next_batch_result = NextBatch(source_chunk); + auto next_batch_result = NextBatch(source_chunk, source_result == SourceResultType::HAVE_MORE_OUTPUT); next_batch_blocked = 
next_batch_result == SinkNextBatchType::BLOCKED; if (next_batch_blocked) { return PipelineExecuteResult::INTERRUPTED; } } - if (exhausted_source && source_chunk.size() == 0) { - // To ensure that we're not early-terminating the pipeline + if (exhausted_pipeline && source_chunk.size() == 0) { continue; } @@ -259,11 +262,12 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { } if (result == OperatorResultType::FINISHED) { + exhausted_pipeline = true; break; } } while (chunk_budget.Next()); - if ((!exhausted_source || !done_flushing) && !IsFinished()) { + if ((!exhausted_pipeline || !done_flushing) && !IsFinished()) { return PipelineExecuteResult::NOT_FINISHED; } diff --git a/src/duckdb/src/parallel/task_executor.cpp b/src/duckdb/src/parallel/task_executor.cpp index fa2c0087c..9487a1427 100644 --- a/src/duckdb/src/parallel/task_executor.cpp +++ b/src/duckdb/src/parallel/task_executor.cpp @@ -69,8 +69,10 @@ TaskExecutionResult BaseExecutorTask::Execute(TaskExecutionMode mode) { return TaskExecutionResult::TASK_FINISHED; } try { - TaskNotifier task_notifier {executor.context}; - ExecuteTask(); + { + TaskNotifier task_notifier {executor.context}; + ExecuteTask(); + } executor.FinishTask(); return TaskExecutionResult::TASK_FINISHED; } catch (std::exception &ex) { diff --git a/src/duckdb/src/parallel/task_scheduler.cpp b/src/duckdb/src/parallel/task_scheduler.cpp index 9d8f94b65..b5ed2db24 100644 --- a/src/duckdb/src/parallel/task_scheduler.cpp +++ b/src/duckdb/src/parallel/task_scheduler.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/numeric_utils.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/storage/block_allocator.hpp" #ifndef DUCKDB_NO_THREADS #include "concurrentqueue.h" #include "duckdb/common/thread.hpp" @@ -270,17 +271,19 @@ void TaskScheduler::ExecuteForever(atomic *marker) { #ifndef DUCKDB_NO_THREADS static constexpr const int64_t INITIAL_FLUSH_WAIT = 500000; // initial wait time of 0.5s (in mus) before flushing - auto &config = DBConfig::GetConfig(db); + const auto &block_allocator = BlockAllocator::Get(db); + const auto &config = DBConfig::GetConfig(db); + shared_ptr task; // loop until the marker is set to false while (*marker) { - if (!Allocator::SupportsFlush()) { + if (!block_allocator.SupportsFlush()) { // allocator can't flush, just start an untimed wait queue->semaphore.wait(); } else if (!queue->semaphore.wait(INITIAL_FLUSH_WAIT)) { // allocator can flush, we flush this threads outstanding allocations after it was idle for 0.5s - Allocator::ThreadFlush(allocator_background_threads, allocator_flush_threshold, - NumericCast(requested_thread_count.load())); + block_allocator.ThreadFlush(allocator_background_threads, allocator_flush_threshold, + NumericCast(requested_thread_count.load())); auto decay_delay = Allocator::DecayDelay(); if (!decay_delay.IsValid()) { // no decay delay specified - just wait @@ -322,8 +325,8 @@ void TaskScheduler::ExecuteForever(atomic *marker) { } } // this thread will exit, flush all of its outstanding allocations - if (Allocator::SupportsFlush()) { - Allocator::ThreadFlush(allocator_background_threads, 0, NumericCast(requested_thread_count.load())); + if (block_allocator.SupportsFlush()) { + block_allocator.ThreadFlush(allocator_background_threads, 0, NumericCast(requested_thread_count.load())); Allocator::ThreadIdle(); } #else @@ -563,9 +566,7 @@ void TaskScheduler::RelaunchThreadsInternal(int32_t n) { } } current_thread_count = NumericCast(threads.size() + 
config.options.external_threads); - if (Allocator::SupportsFlush()) { - Allocator::FlushAll(); - } + BlockAllocator::Get(db).FlushAll(); #endif } diff --git a/src/duckdb/src/parser/column_definition.cpp b/src/duckdb/src/parser/column_definition.cpp index ed19f7715..497eec6d2 100644 --- a/src/duckdb/src/parser/column_definition.cpp +++ b/src/duckdb/src/parser/column_definition.cpp @@ -1,8 +1,8 @@ #include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" #include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/constraint.cpp b/src/duckdb/src/parser/constraint.cpp index c06a5c800..67aaec651 100644 --- a/src/duckdb/src/parser/constraint.cpp +++ b/src/duckdb/src/parser/constraint.cpp @@ -1,7 +1,6 @@ #include "duckdb/parser/constraint.hpp" #include "duckdb/common/printer.hpp" -#include "duckdb/parser/constraints/list.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/constraints/not_null_constraint.cpp b/src/duckdb/src/parser/constraints/not_null_constraint.cpp index fb406e5a7..bb85fc36f 100644 --- a/src/duckdb/src/parser/constraints/not_null_constraint.cpp +++ b/src/duckdb/src/parser/constraints/not_null_constraint.cpp @@ -1,5 +1,7 @@ #include "duckdb/parser/constraints/not_null_constraint.hpp" +#include "duckdb/common/helper.hpp" + namespace duckdb { NotNullConstraint::NotNullConstraint(LogicalIndex index) : Constraint(ConstraintType::NOT_NULL), index(index) { diff --git a/src/duckdb/src/parser/constraints/unique_constraint.cpp b/src/duckdb/src/parser/constraints/unique_constraint.cpp index d3379be42..20bcb76a5 100644 --- a/src/duckdb/src/parser/constraints/unique_constraint.cpp +++ b/src/duckdb/src/parser/constraints/unique_constraint.cpp @@ -1,6 +1,7 @@ #include "duckdb/parser/constraints/unique_constraint.hpp" - #include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/enums/index_constraint_type.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/expression/case_expression.cpp b/src/duckdb/src/parser/expression/case_expression.cpp index fd76081ac..4b009b831 100644 --- a/src/duckdb/src/parser/expression/case_expression.cpp +++ b/src/duckdb/src/parser/expression/case_expression.cpp @@ -1,10 +1,6 @@ #include "duckdb/parser/expression/case_expression.hpp" - #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { CaseExpression::CaseExpression() : ParsedExpression(ExpressionType::CASE_EXPR, ExpressionClass::CASE) { diff --git a/src/duckdb/src/parser/expression/cast_expression.cpp b/src/duckdb/src/parser/expression/cast_expression.cpp index 758cc29ce..6cc43bc05 100644 --- a/src/duckdb/src/parser/expression/cast_expression.cpp +++ b/src/duckdb/src/parser/expression/cast_expression.cpp @@ -1,10 +1,6 @@ #include "duckdb/parser/expression/cast_expression.hpp" - #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { CastExpression::CastExpression(LogicalType target, unique_ptr child, bool try_cast_p) diff --git a/src/duckdb/src/parser/expression/collate_expression.cpp b/src/duckdb/src/parser/expression/collate_expression.cpp index 
70b754f9c..2874c0033 100644 --- a/src/duckdb/src/parser/expression/collate_expression.cpp +++ b/src/duckdb/src/parser/expression/collate_expression.cpp @@ -1,10 +1,6 @@ #include "duckdb/parser/expression/collate_expression.hpp" - #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { CollateExpression::CollateExpression(string collation_p, unique_ptr child) diff --git a/src/duckdb/src/parser/expression/columnref_expression.cpp b/src/duckdb/src/parser/expression/columnref_expression.cpp index c1eb6e9cd..a5e5193e6 100644 --- a/src/duckdb/src/parser/expression/columnref_expression.cpp +++ b/src/duckdb/src/parser/expression/columnref_expression.cpp @@ -2,8 +2,8 @@ #include "duckdb/common/types/hash.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/parser/qualified_name.hpp" #include "duckdb/planner/binding_alias.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/expression/comparison_expression.cpp b/src/duckdb/src/parser/expression/comparison_expression.cpp index d0649b35a..15d930839 100644 --- a/src/duckdb/src/parser/expression/comparison_expression.cpp +++ b/src/duckdb/src/parser/expression/comparison_expression.cpp @@ -1,11 +1,5 @@ #include "duckdb/parser/expression/comparison_expression.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/parser/expression/cast_expression.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { ComparisonExpression::ComparisonExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::COMPARISON) { diff --git a/src/duckdb/src/parser/expression/conjunction_expression.cpp b/src/duckdb/src/parser/expression/conjunction_expression.cpp index a11759b8f..0b689e01d 100644 --- a/src/duckdb/src/parser/expression/conjunction_expression.cpp +++ b/src/duckdb/src/parser/expression/conjunction_expression.cpp @@ -2,9 +2,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/parser/expression_util.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { ConjunctionExpression::ConjunctionExpression(ExpressionType type) diff --git a/src/duckdb/src/parser/expression/constant_expression.cpp b/src/duckdb/src/parser/expression/constant_expression.cpp index 437687b14..5e19bc46b 100644 --- a/src/duckdb/src/parser/expression/constant_expression.cpp +++ b/src/duckdb/src/parser/expression/constant_expression.cpp @@ -1,12 +1,8 @@ #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/types/hash.hpp" #include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { ConstantExpression::ConstantExpression() : ParsedExpression(ExpressionType::VALUE_CONSTANT, ExpressionClass::CONSTANT) { diff --git a/src/duckdb/src/parser/expression/default_expression.cpp b/src/duckdb/src/parser/expression/default_expression.cpp index 7618fd21b..7fe7d8f91 100644 --- a/src/duckdb/src/parser/expression/default_expression.cpp +++ b/src/duckdb/src/parser/expression/default_expression.cpp @@ -1,10 +1,5 @@ #include "duckdb/parser/expression/default_expression.hpp" -#include "duckdb/common/exception.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include 
"duckdb/common/serializer/deserializer.hpp" - namespace duckdb { DefaultExpression::DefaultExpression() : ParsedExpression(ExpressionType::VALUE_DEFAULT, ExpressionClass::DEFAULT) { diff --git a/src/duckdb/src/parser/expression/function_expression.cpp b/src/duckdb/src/parser/expression/function_expression.cpp index 6d96a0cc0..8503fa657 100644 --- a/src/duckdb/src/parser/expression/function_expression.cpp +++ b/src/duckdb/src/parser/expression/function_expression.cpp @@ -5,9 +5,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/types/hash.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { FunctionExpression::FunctionExpression() : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION) { diff --git a/src/duckdb/src/parser/expression/lambda_expression.cpp b/src/duckdb/src/parser/expression/lambda_expression.cpp index d8d4fe891..2b14abd61 100644 --- a/src/duckdb/src/parser/expression/lambda_expression.cpp +++ b/src/duckdb/src/parser/expression/lambda_expression.cpp @@ -1,11 +1,9 @@ #include "duckdb/parser/expression/lambda_expression.hpp" #include "duckdb/common/types/hash.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { @@ -34,7 +32,6 @@ LambdaExpression::LambdaExpression(unique_ptr lhs, unique_ptr< } vector> LambdaExpression::ExtractColumnRefExpressions(string &error_message) const { - // we return an error message because we can't throw a binder exception here, // since we can't distinguish between a lambda function and the JSON operator yet vector> column_refs; diff --git a/src/duckdb/src/parser/expression/lambdaref_expression.cpp b/src/duckdb/src/parser/expression/lambdaref_expression.cpp index fed844fea..f1e7e59bf 100644 --- a/src/duckdb/src/parser/expression/lambdaref_expression.cpp +++ b/src/duckdb/src/parser/expression/lambdaref_expression.cpp @@ -37,7 +37,6 @@ unique_ptr LambdaRefExpression::Copy() const { unique_ptr LambdaRefExpression::FindMatchingBinding(optional_ptr> &lambda_bindings, const string &column_name) { - // if this is a lambda parameter, then we temporarily add a BoundLambdaRef, // which we capture and remove later @@ -47,7 +46,7 @@ LambdaRefExpression::FindMatchingBinding(optional_ptr> &lam if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { if ((*lambda_bindings)[i - 1].HasMatchingBinding(column_name)) { - D_ASSERT((*lambda_bindings)[i - 1].alias.IsSet()); + D_ASSERT((*lambda_bindings)[i - 1].GetBindingAlias().IsSet()); return make_uniq(i - 1, column_name); } } diff --git a/src/duckdb/src/parser/expression/operator_expression.cpp b/src/duckdb/src/parser/expression/operator_expression.cpp index 79d90f674..5b248bb1c 100644 --- a/src/duckdb/src/parser/expression/operator_expression.cpp +++ b/src/duckdb/src/parser/expression/operator_expression.cpp @@ -2,9 +2,6 @@ #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { OperatorExpression::OperatorExpression(ExpressionType type, unique_ptr left, diff --git a/src/duckdb/src/parser/expression/parameter_expression.cpp b/src/duckdb/src/parser/expression/parameter_expression.cpp index 034c4eaef..90149d195 100644 --- a/src/duckdb/src/parser/expression/parameter_expression.cpp +++ 
b/src/duckdb/src/parser/expression/parameter_expression.cpp @@ -2,10 +2,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/types/hash.hpp" -#include "duckdb/common/to_string.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/expression/positional_reference_expression.cpp b/src/duckdb/src/parser/expression/positional_reference_expression.cpp index e73bbeddf..b1a4d54b5 100644 --- a/src/duckdb/src/parser/expression/positional_reference_expression.cpp +++ b/src/duckdb/src/parser/expression/positional_reference_expression.cpp @@ -4,9 +4,6 @@ #include "duckdb/common/types/hash.hpp" #include "duckdb/common/to_string.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { PositionalReferenceExpression::PositionalReferenceExpression() diff --git a/src/duckdb/src/parser/expression/star_expression.cpp b/src/duckdb/src/parser/expression/star_expression.cpp index f18aee0a9..70ee1e72d 100644 --- a/src/duckdb/src/parser/expression/star_expression.cpp +++ b/src/duckdb/src/parser/expression/star_expression.cpp @@ -4,7 +4,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/expression/subquery_expression.cpp b/src/duckdb/src/parser/expression/subquery_expression.cpp index 2cc1e4548..759675e62 100644 --- a/src/duckdb/src/parser/expression/subquery_expression.cpp +++ b/src/duckdb/src/parser/expression/subquery_expression.cpp @@ -1,8 +1,6 @@ #include "duckdb/parser/expression/subquery_expression.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/expression/window_expression.cpp b/src/duckdb/src/parser/expression/window_expression.cpp index 9720d2abf..982e4a425 100644 --- a/src/duckdb/src/parser/expression/window_expression.cpp +++ b/src/duckdb/src/parser/expression/window_expression.cpp @@ -1,12 +1,7 @@ #include "duckdb/parser/expression/window_expression.hpp" -#include "duckdb/common/limits.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { WindowExpression::WindowExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::WINDOW) { @@ -35,31 +30,35 @@ WindowExpression::WindowExpression(ExpressionType type, string catalog_name, str } } +static const WindowFunctionDefinition internal_window_functions[] = { + {"rank", ExpressionType::WINDOW_RANK}, + {"rank_dense", ExpressionType::WINDOW_RANK_DENSE}, + {"dense_rank", ExpressionType::WINDOW_RANK_DENSE}, + {"percent_rank", ExpressionType::WINDOW_PERCENT_RANK}, + {"row_number", ExpressionType::WINDOW_ROW_NUMBER}, + {"first_value", ExpressionType::WINDOW_FIRST_VALUE}, + {"first", ExpressionType::WINDOW_FIRST_VALUE}, + {"last_value", ExpressionType::WINDOW_LAST_VALUE}, + {"last", ExpressionType::WINDOW_LAST_VALUE}, + {"nth_value", ExpressionType::WINDOW_NTH_VALUE}, + {"cume_dist", ExpressionType::WINDOW_CUME_DIST}, + {"lead", ExpressionType::WINDOW_LEAD}, + {"lag", ExpressionType::WINDOW_LAG}, + {"ntile", ExpressionType::WINDOW_NTILE}, + {"fill", ExpressionType::WINDOW_FILL}, + {nullptr, 
ExpressionType::INVALID}}; + +const WindowFunctionDefinition *WindowExpression::WindowFunctions() { + return internal_window_functions; +} + ExpressionType WindowExpression::WindowToExpressionType(string &fun_name) { - if (fun_name == "rank") { - return ExpressionType::WINDOW_RANK; - } else if (fun_name == "rank_dense" || fun_name == "dense_rank") { - return ExpressionType::WINDOW_RANK_DENSE; - } else if (fun_name == "percent_rank") { - return ExpressionType::WINDOW_PERCENT_RANK; - } else if (fun_name == "row_number") { - return ExpressionType::WINDOW_ROW_NUMBER; - } else if (fun_name == "first_value" || fun_name == "first") { - return ExpressionType::WINDOW_FIRST_VALUE; - } else if (fun_name == "last_value" || fun_name == "last") { - return ExpressionType::WINDOW_LAST_VALUE; - } else if (fun_name == "nth_value") { - return ExpressionType::WINDOW_NTH_VALUE; - } else if (fun_name == "cume_dist") { - return ExpressionType::WINDOW_CUME_DIST; - } else if (fun_name == "lead") { - return ExpressionType::WINDOW_LEAD; - } else if (fun_name == "lag") { - return ExpressionType::WINDOW_LAG; - } else if (fun_name == "ntile") { - return ExpressionType::WINDOW_NTILE; - } else if (fun_name == "fill") { - return ExpressionType::WINDOW_FILL; + D_ASSERT(StringUtil::IsLower(fun_name)); + auto functions = WindowFunctions(); + for (idx_t i = 0; functions[i].name != nullptr; i++) { + if (fun_name == functions[i].name) { + return functions[i].expression_type; + } } return ExpressionType::WINDOW_AGGREGATE; } diff --git a/src/duckdb/src/parser/parsed_data/alter_info.cpp b/src/duckdb/src/parser/parsed_data/alter_info.cpp index 2f90d0abf..671f27663 100644 --- a/src/duckdb/src/parser/parsed_data/alter_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_info.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/parsed_data/alter_info.hpp" -#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" #include "duckdb/parser/parsed_data/alter_table_info.hpp" #include "duckdb/parser/constraints/unique_constraint.hpp" diff --git a/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp b/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp index 3de4fc52a..7a596dbbb 100644 --- a/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_scalar_function_info.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/constraint.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp b/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp index e7ce608c8..e5e8a88ac 100644 --- a/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_table_function_info.cpp @@ -1,7 +1,5 @@ #include "duckdb/parser/parsed_data/alter_table_function_info.hpp" -#include "duckdb/parser/constraint.hpp" - namespace duckdb { //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/parser/parsed_data/alter_table_info.cpp b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp index 01bdc4da9..8d50e3dd6 100644 --- a/src/duckdb/src/parser/parsed_data/alter_table_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp @@ -1,7 +1,7 @@ #include "duckdb/parser/parsed_data/alter_table_info.hpp" #include "duckdb/common/extra_type_info.hpp" - #include "duckdb/parser/constraint.hpp" +#include 
"duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/attach_info.cpp b/src/duckdb/src/parser/parsed_data/attach_info.cpp index 333c27c30..e9316c2ad 100644 --- a/src/duckdb/src/parser/parsed_data/attach_info.cpp +++ b/src/duckdb/src/parser/parsed_data/attach_info.cpp @@ -1,8 +1,5 @@ #include "duckdb/parser/parsed_data/attach_info.hpp" #include "duckdb/parser/keyword_helper.hpp" - -#include "duckdb/storage/storage_info.hpp" -#include "duckdb/common/optional_idx.hpp" #include "duckdb/main/config.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/create_info.cpp b/src/duckdb/src/parser/parsed_data/create_info.cpp index 7f25b8c76..ff7b58575 100644 --- a/src/duckdb/src/parser/parsed_data/create_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_info.cpp @@ -1,11 +1,6 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/parsed_data/create_index_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/parser/parsed_data/create_table_info.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" #include "duckdb/parser/parsed_data/create_macro_info.hpp" diff --git a/src/duckdb/src/parser/parsed_data/create_macro_info.cpp b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp index 9c629d395..64b979cb4 100644 --- a/src/duckdb/src/parser/parsed_data/create_macro_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp @@ -1,5 +1,4 @@ #include "duckdb/parser/parsed_data/create_macro_info.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/parser/keyword_helper.hpp" diff --git a/src/duckdb/src/parser/parsed_data/create_schema_info.cpp b/src/duckdb/src/parser/parsed_data/create_schema_info.cpp index e7c7f3f8b..8eb6faf2d 100644 --- a/src/duckdb/src/parser/parsed_data/create_schema_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_schema_info.cpp @@ -1,5 +1,4 @@ #include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp index 560ff3d40..76728e669 100644 --- a/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp @@ -1,7 +1,4 @@ #include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/create_type_info.cpp b/src/duckdb/src/parser/parsed_data/create_type_info.cpp index 2f3ec9000..c48799402 100644 --- a/src/duckdb/src/parser/parsed_data/create_type_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_type_info.cpp @@ -1,7 +1,4 @@ #include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog.hpp" #include "duckdb/common/extra_type_info.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/sample_options.cpp 
b/src/duckdb/src/parser/parsed_data/sample_options.cpp index 54be9d1cb..1dfcba72f 100644 --- a/src/duckdb/src/parser/parsed_data/sample_options.cpp +++ b/src/duckdb/src/parser/parsed_data/sample_options.cpp @@ -1,6 +1,6 @@ #include "duckdb/parser/parsed_data/sample_options.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/to_string.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_data/vacuum_info.cpp b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp index 7fff15e83..2ba43d2fc 100644 --- a/src/duckdb/src/parser/parsed_data/vacuum_info.cpp +++ b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp @@ -1,4 +1,5 @@ #include "duckdb/parser/parsed_data/vacuum_info.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/parsed_expression.cpp b/src/duckdb/src/parser/parsed_expression.cpp index 1ba33bfd8..2732e760c 100644 --- a/src/duckdb/src/parser/parsed_expression.cpp +++ b/src/duckdb/src/parser/parsed_expression.cpp @@ -4,7 +4,6 @@ #include "duckdb/common/types/hash.hpp" #include "duckdb/parser/expression/list.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" -#include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/parser/expression_util.hpp" diff --git a/src/duckdb/src/parser/parsed_expression_iterator.cpp b/src/duckdb/src/parser/parsed_expression_iterator.cpp index 7ca38a10e..f5746f9f7 100644 --- a/src/duckdb/src/parser/parsed_expression_iterator.cpp +++ b/src/duckdb/src/parser/parsed_expression_iterator.cpp @@ -162,7 +162,6 @@ void ParsedExpressionIterator::EnumerateChildren( void ParsedExpressionIterator::EnumerateQueryNodeModifiers( QueryNode &node, const std::function &child)> &callback) { - for (auto &modifier : node.modifiers) { switch (modifier->type) { case ResultModifierType::LIMIT_MODIFIER: { @@ -271,12 +270,6 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( EnumerateQueryNodeChildren(*rcte_node.right, expr_callback, ref_callback); break; } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.query, expr_callback, ref_callback); - EnumerateQueryNodeChildren(*cte_node.child, expr_callback, ref_callback); - break; - } case QueryNodeType::SELECT_NODE: { auto &sel_node = node.Cast(); for (idx_t i = 0; i < sel_node.select_list.size(); i++) { diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp index 552b6e180..22649ca8d 100644 --- a/src/duckdb/src/parser/parser.cpp +++ b/src/duckdb/src/parser/parser.cpp @@ -1,10 +1,8 @@ #include "duckdb/parser/parser.hpp" -#include "duckdb/parser/expression/cast_expression.hpp" #include "duckdb/parser/group_by_node.hpp" #include "duckdb/parser/parsed_data/create_table_info.hpp" #include "duckdb/parser/parser_extension.hpp" -#include "duckdb/parser/query_error_context.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/statement/create_statement.hpp" #include "duckdb/parser/statement/extension_statement.hpp" @@ -44,8 +42,8 @@ static bool ReplaceUnicodeSpaces(const string &query, string &new_query, vector< } static bool IsValidDollarQuotedStringTagFirstChar(const unsigned char &c) { - // the first character can be between A-Z, a-z, or \200 - \377 - return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c >= 0x80; + // the first character can be between A-Z, a-z, 
underscore, or \200 - \377 + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_' || c >= 0x80; } static bool IsValidDollarQuotedStringTagSubsequentChar(const unsigned char &c) { @@ -165,33 +163,74 @@ bool Parser::StripUnicodeSpaces(const string &query_str, string &new_query) { return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces); } -vector SplitQueryStringIntoStatements(const string &query) { - // Break sql string down into sql statements using the tokenizer - vector query_statements; - auto tokens = Parser::Tokenize(query); - idx_t next_statement_start = 0; - for (idx_t i = 1; i < tokens.size(); ++i) { - auto &t_prev = tokens[i - 1]; - auto &t = tokens[i]; - if (t_prev.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR) { - // LCOV_EXCL_START - for (idx_t c = t_prev.start; c <= t.start; ++c) { - if (query.c_str()[c] == ';') { - query_statements.emplace_back(query.substr(next_statement_start, t.start - next_statement_start)); - next_statement_start = tokens[i].start; - } +vector SplitQueries(const string &input_query) { + vector queries; + auto tokenized_input = Parser::Tokenize(input_query); + size_t last_split = 0; + + for (const auto &token : tokenized_input) { + if (token.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR && input_query[token.start] == ';') { + string segment = input_query.substr(last_split, token.start - last_split); + StringUtil::Trim(segment); + if (!segment.empty()) { + segment.append(";"); + queries.push_back(std::move(segment)); } - // LCOV_EXCL_STOP + last_split = token.start + 1; } } - query_statements.emplace_back(query.substr(next_statement_start, query.size() - next_statement_start)); - return query_statements; + string final_segment = input_query.substr(last_split); + StringUtil::Trim(final_segment); + if (!final_segment.empty()) { + final_segment.append(";"); + queries.push_back(std::move(final_segment)); + } + return queries; +} + +unique_ptr Parser::GetStatement(const string &query) { + Transformer transformer(options); + vector> statements; + PostgresParser parser; + parser.Parse(query); + if (parser.success) { + if (!parser.parse_tree) { + // empty statement + return {}; + } + transformer.TransformParseTree(parser.parse_tree, statements); + return std::move(statements[0]); + } + return {}; +} + +void Parser::ThrowParserOverrideError(ParserOverrideResult &result) { + if (result.type == ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR) { + throw ParserException("Parser override failed to return a valid statement: %s\n\nConsider restarting the " + "database and " + "using the setting \"set allow_parser_override_extension=fallback\" to fallback to the " + "default parser.", + result.error.RawMessage()); + } + if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { + if (result.error.Type() == ExceptionType::NOT_IMPLEMENTED) { + throw NotImplementedException("Parser override has not yet implemented this " + "transformer rule.\nOriginal error: %s", + result.error.RawMessage()); + } + if (result.error.Type() == ExceptionType::PARSER) { + throw ParserException("Parser override could not parse this query.\nOriginal error: %s", + result.error.RawMessage()); + } + result.error.Throw(); + } } void Parser::ParseQuery(const string &query) { Transformer transformer(options); string parser_error; optional_idx parser_error_location; + string parser_override_option = StringUtil::Lower(options.parser_override_setting); { // check if there are any unicode spaces in the string string new_query; @@ -207,12 +246,62 @@ 
void Parser::ParseQuery(const string &query) { if (!ext.parser_override) { continue; } + if (StringUtil::CIEquals(parser_override_option, "default")) { + continue; + } auto result = ext.parser_override(ext.parser_info.get(), query); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { statements = std::move(result.statements); return; - } else if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { - throw ParserException(result.error); + } + if (StringUtil::CIEquals(parser_override_option, "strict")) { + ThrowParserOverrideError(result); + } + if (StringUtil::CIEquals(parser_override_option, "strict_when_supported")) { + auto statement = GetStatement(query); + if (!statement) { + break; + } + bool is_supported = false; + switch (statement->type) { + case StatementType::CALL_STATEMENT: + case StatementType::TRANSACTION_STATEMENT: + case StatementType::VARIABLE_SET_STATEMENT: + case StatementType::LOAD_STATEMENT: + case StatementType::ATTACH_STATEMENT: + case StatementType::DETACH_STATEMENT: + case StatementType::DELETE_STATEMENT: + case StatementType::DROP_STATEMENT: + case StatementType::ALTER_STATEMENT: + case StatementType::PRAGMA_STATEMENT: + case StatementType::COPY_DATABASE_STATEMENT: + is_supported = true; + break; + case StatementType::CREATE_STATEMENT: { + auto &create_statement = statement->Cast(); + switch (create_statement.info->type) { + case CatalogType::INDEX_ENTRY: + case CatalogType::MACRO_ENTRY: + case CatalogType::SCHEMA_ENTRY: + case CatalogType::SECRET_ENTRY: + case CatalogType::SEQUENCE_ENTRY: + case CatalogType::TYPE_ENTRY: + is_supported = true; + break; + default: + is_supported = false; + } + break; + } + default: + is_supported = false; + break; + } + if (is_supported) { + ThrowParserOverrideError(result); + } + } else if (StringUtil::CIEquals(parser_override_option, "fallback")) { + continue; } } } @@ -250,9 +339,9 @@ void Parser::ParseQuery(const string &query) { throw ParserException::SyntaxError(query, parser_error, parser_error_location); } else { // split sql string into statements and re-parse using extension - auto query_statements = SplitQueryStringIntoStatements(query); + auto queries = SplitQueries(query); idx_t stmt_loc = 0; - for (auto const &query_statement : query_statements) { + for (auto const &query_statement : queries) { ErrorData another_parser_error; // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant { @@ -284,7 +373,9 @@ void Parser::ParseQuery(const string &query) { bool parsed_single_statement = false; for (auto &ext : *options.extensions) { D_ASSERT(!parsed_single_statement); - D_ASSERT(ext.parse_function); + if (!ext.parse_function) { + continue; + } auto result = ext.parse_function(ext.parser_info.get(), query_statement); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { auto statement = make_uniq(ext, std::move(result.parse_data)); @@ -362,18 +453,23 @@ vector Parser::TokenizeError(const string &error_msg) { vector tokens; // find "XXX Error:" - this marks the start of the error message - auto error = StringUtil::Find(error_msg, "Error: "); + auto error = StringUtil::Find(error_msg, "Error:"); if (error.IsValid()) { SimplifiedToken token; - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_EMPHASIS; token.start = 0; tokens.push_back(token); - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; token.start = 
error.GetIndex() + 6; tokens.push_back(token); - error_start = error.GetIndex() + 7; + error_start = error.GetIndex() + 6; + } else { + SimplifiedToken token; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; + token.start = 0; + tokens.push_back(token); } // find "LINE (number)" - this marks the end of the message @@ -392,7 +488,7 @@ vector Parser::TokenizeError(const string &error_msg) { if (error_msg[i] == quote_char) { SimplifiedToken token; token.start = i; - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; tokens.push_back(token); in_quotes = false; } @@ -405,7 +501,7 @@ vector Parser::TokenizeError(const string &error_msg) { // not quoted and found a quote - enter the quoted state SimplifiedToken token; token.start = i; - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_SUGGESTION; token.start++; tokens.push_back(token); quote_char = error_msg[i]; @@ -419,7 +515,7 @@ vector Parser::TokenizeError(const string &error_msg) { if (line_pos.IsValid()) { SimplifiedToken token; token.start = line_pos.GetIndex() + 1; - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_EMPHASIS; tokens.push_back(token); // tokenize the LINE part @@ -431,7 +527,7 @@ vector Parser::TokenizeError(const string &error_msg) { } if (query_start < error_msg.size()) { token.start = query_start; - token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER; + token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; tokens.push_back(token); idx_t query_end; @@ -455,25 +551,34 @@ vector Parser::TokenizeError(const string &error_msg) { } } } + // tokenize the actual query string query = error_msg.substr(query_start, query_end - query_start); auto query_tokens = Tokenize(query); for (auto &query_token : query_tokens) { if (place_caret) { + // find the caret position and highlight the identifier it points to if (query_token.start >= caret_position) { // we need to place the caret here query_token.start = query_start + caret_position; - query_token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; + query_token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_EMPHASIS; tokens.push_back(query_token); place_caret = false; continue; } } + switch (query_token.type) { + case SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD: + query_token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR_EMPHASIS; + break; + default: + query_token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; + break; + } query_token.start += query_start; tokens.push_back(query_token); } - // FIXME: find the caret position and highlight/bold the identifier it points to if (query_end < error_msg.size()) { token.start = query_end; token.type = SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR; diff --git a/src/duckdb/src/parser/qualified_name.cpp b/src/duckdb/src/parser/qualified_name.cpp index febb8fccc..3d5759c25 100644 --- a/src/duckdb/src/parser/qualified_name.cpp +++ b/src/duckdb/src/parser/qualified_name.cpp @@ -1,5 +1,6 @@ #include "duckdb/parser/qualified_name.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/query_error_context.cpp b/src/duckdb/src/parser/query_error_context.cpp index 44defbce0..2226802de 100644 --- a/src/duckdb/src/parser/query_error_context.cpp +++ 
b/src/duckdb/src/parser/query_error_context.cpp @@ -1,8 +1,7 @@ #include "duckdb/parser/query_error_context.hpp" -#include "duckdb/parser/sql_statement.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" #include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/to_string.hpp" #include "utf8proc_wrapper.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/query_node/cte_node.cpp b/src/duckdb/src/parser/query_node/cte_node.cpp index 1e1f0e199..d23e42491 100644 --- a/src/duckdb/src/parser/query_node/cte_node.cpp +++ b/src/duckdb/src/parser/query_node/cte_node.cpp @@ -1,42 +1,17 @@ #include "duckdb/parser/query_node/cte_node.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { string CTENode::ToString() const { - string result; - result += child->ToString(); - return result; + throw InternalException("CTENode is a legacy type"); } bool CTENode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (!query->Equals(other.query.get())) { - return false; - } - if (!child->Equals(other.child.get())) { - return false; - } - return true; + throw InternalException("CTENode is a legacy type"); } unique_ptr CTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->query = query->Copy(); - result->child = child->Copy(); - result->aliases = aliases; - result->materialized = materialized; - this->CopyProperties(*result); - return std::move(result); + throw InternalException("CTENode is a legacy type"); } } // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/recursive_cte_node.cpp b/src/duckdb/src/parser/query_node/recursive_cte_node.cpp index 993c38f6b..eb6c1e009 100644 --- a/src/duckdb/src/parser/query_node/recursive_cte_node.cpp +++ b/src/duckdb/src/parser/query_node/recursive_cte_node.cpp @@ -1,5 +1,4 @@ #include "duckdb/parser/query_node/recursive_cte_node.hpp" -#include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/query_node/select_node.cpp b/src/duckdb/src/parser/query_node/select_node.cpp index 71c228c40..addf992f3 100644 --- a/src/duckdb/src/parser/query_node/select_node.cpp +++ b/src/duckdb/src/parser/query_node/select_node.cpp @@ -1,8 +1,6 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression_util.hpp" -#include "duckdb/parser/keyword_helper.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/query_node/set_operation_node.cpp b/src/duckdb/src/parser/query_node/set_operation_node.cpp index a8b624f21..30d36defc 100644 --- a/src/duckdb/src/parser/query_node/set_operation_node.cpp +++ b/src/duckdb/src/parser/query_node/set_operation_node.cpp @@ -1,17 +1,11 @@ #include "duckdb/parser/query_node/set_operation_node.hpp" - #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { SetOperationNode::SetOperationNode() : QueryNode(QueryNodeType::SET_OPERATION_NODE) { } -const vector> &SetOperationNode::GetSelectList() const { - return children[0]->GetSelectList(); -} - string SetOperationNode::ToString() const { string result; result = cte_map.ToString(); diff --git 
a/src/duckdb/src/parser/query_node/statement_node.cpp b/src/duckdb/src/parser/query_node/statement_node.cpp new file mode 100644 index 000000000..66e7b8e5a --- /dev/null +++ b/src/duckdb/src/parser/query_node/statement_node.cpp @@ -0,0 +1,40 @@ +#include "duckdb/parser/query_node/statement_node.hpp" + +namespace duckdb { + +StatementNode::StatementNode(SQLStatement &stmt_p) : QueryNode(QueryNodeType::STATEMENT_NODE), stmt(stmt_p) { +} + +//! Convert the query node to a string +string StatementNode::ToString() const { + return stmt.ToString(); +} + +bool StatementNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + return RefersToSameObject(stmt, other.stmt); +} + +//! Create a copy of this SelectNode +unique_ptr StatementNode::Copy() const { + return make_uniq(stmt); +} + +//! Serializes a QueryNode to a stand-alone binary blob +//! Deserializes a blob back into a QueryNode + +void StatementNode::Serialize(Serializer &serializer) const { + throw InternalException("StatementNode cannot be serialized"); +} + +unique_ptr StatementNode::Deserialize(Deserializer &source) { + throw InternalException("StatementNode cannot be deserialized"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/result_modifier.cpp b/src/duckdb/src/parser/result_modifier.cpp index eae317e4c..e22108aea 100644 --- a/src/duckdb/src/parser/result_modifier.cpp +++ b/src/duckdb/src/parser/result_modifier.cpp @@ -1,7 +1,5 @@ #include "duckdb/parser/result_modifier.hpp" #include "duckdb/parser/expression_util.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/statement/export_statement.cpp b/src/duckdb/src/parser/statement/export_statement.cpp index ad28634a0..decef215a 100644 --- a/src/duckdb/src/parser/statement/export_statement.cpp +++ b/src/duckdb/src/parser/statement/export_statement.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/statement/export_statement.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" -#include "duckdb/parser/query_node.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/statement/relation_statement.cpp b/src/duckdb/src/parser/statement/relation_statement.cpp index 9b3801495..023d3cac9 100644 --- a/src/duckdb/src/parser/statement/relation_statement.cpp +++ b/src/duckdb/src/parser/statement/relation_statement.cpp @@ -5,10 +5,7 @@ namespace duckdb { RelationStatement::RelationStatement(shared_ptr relation_p) : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation_p)) { - if (relation->type == RelationType::QUERY_RELATION) { - auto &query_relation = relation->Cast(); - query = query_relation.query; - } + query = relation->GetQuery(); } unique_ptr RelationStatement::Copy() const { diff --git a/src/duckdb/src/parser/statement/select_statement.cpp b/src/duckdb/src/parser/statement/select_statement.cpp index 1c686c675..77002a9e2 100644 --- a/src/duckdb/src/parser/statement/select_statement.cpp +++ b/src/duckdb/src/parser/statement/select_statement.cpp @@ -1,8 +1,5 @@ #include "duckdb/parser/statement/select_statement.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { SelectStatement::SelectStatement(const SelectStatement &other) : SQLStatement(other), node(other.node->Copy()) { diff --git a/src/duckdb/src/parser/statement/update_statement.cpp 
b/src/duckdb/src/parser/statement/update_statement.cpp index b09fd1a0d..115e76d7f 100644 --- a/src/duckdb/src/parser/statement/update_statement.cpp +++ b/src/duckdb/src/parser/statement/update_statement.cpp @@ -1,5 +1,4 @@ #include "duckdb/parser/statement/update_statement.hpp" -#include "duckdb/parser/query_node/select_node.hpp" namespace duckdb { @@ -49,7 +48,6 @@ UpdateStatement::UpdateStatement(const UpdateStatement &other) } string UpdateStatement::ToString() const { - string result; result = cte_map.ToString(); result += "UPDATE "; diff --git a/src/duckdb/src/parser/tableref.cpp b/src/duckdb/src/parser/tableref.cpp index b97f7ecb4..04ec1fd87 100644 --- a/src/duckdb/src/parser/tableref.cpp +++ b/src/duckdb/src/parser/tableref.cpp @@ -1,9 +1,8 @@ #include "duckdb/parser/tableref.hpp" #include "duckdb/common/printer.hpp" -#include "duckdb/parser/tableref/list.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/enum_util.hpp" #include "duckdb/common/to_string.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp index a2ecb5086..6bf31269c 100644 --- a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp +++ b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp @@ -2,7 +2,7 @@ namespace duckdb { -BoundRefWrapper::BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p) +BoundRefWrapper::BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p) : TableRef(TableReferenceType::BOUND_TABLE_REF), bound_ref(std::move(bound_ref_p)), binder(std::move(binder_p)) { } diff --git a/src/duckdb/src/parser/tableref/column_data_ref.cpp b/src/duckdb/src/parser/tableref/column_data_ref.cpp index 766c5b04e..6336e9204 100644 --- a/src/duckdb/src/parser/tableref/column_data_ref.cpp +++ b/src/duckdb/src/parser/tableref/column_data_ref.cpp @@ -1,9 +1,6 @@ #include "duckdb/parser/tableref/column_data_ref.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { ColumnDataRef::ColumnDataRef(optionally_owned_ptr collection_p, vector expected_names) diff --git a/src/duckdb/src/parser/tableref/expressionlistref.cpp b/src/duckdb/src/parser/tableref/expressionlistref.cpp index c14483a1b..800dac4be 100644 --- a/src/duckdb/src/parser/tableref/expressionlistref.cpp +++ b/src/duckdb/src/parser/tableref/expressionlistref.cpp @@ -1,8 +1,5 @@ #include "duckdb/parser/tableref/expressionlistref.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - namespace duckdb { string ExpressionListRef::ToString() const { diff --git a/src/duckdb/src/parser/tableref/joinref.cpp b/src/duckdb/src/parser/tableref/joinref.cpp index a36af4e46..a738cc5b6 100644 --- a/src/duckdb/src/parser/tableref/joinref.cpp +++ b/src/duckdb/src/parser/tableref/joinref.cpp @@ -2,7 +2,6 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/tableref/pivotref.cpp b/src/duckdb/src/parser/tableref/pivotref.cpp index ffeef36f4..8874fdee7 100644 --- a/src/duckdb/src/parser/tableref/pivotref.cpp +++ b/src/duckdb/src/parser/tableref/pivotref.cpp @@ -1,5 +1,5 @@ #include "duckdb/parser/tableref/pivotref.hpp" - +#include 
"duckdb/parser/expression_util.hpp" #include "duckdb/common/limits.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/tableref/subqueryref.cpp b/src/duckdb/src/parser/tableref/subqueryref.cpp index 2d3214bdc..4d9d09880 100644 --- a/src/duckdb/src/parser/tableref/subqueryref.cpp +++ b/src/duckdb/src/parser/tableref/subqueryref.cpp @@ -1,8 +1,6 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/common/limits.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/tableref/table_function.cpp b/src/duckdb/src/parser/tableref/table_function.cpp index 547eacc2d..735623a9a 100644 --- a/src/duckdb/src/parser/tableref/table_function.cpp +++ b/src/duckdb/src/parser/tableref/table_function.cpp @@ -1,7 +1,5 @@ #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/common/vector.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp b/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp index b31f33981..256a10200 100644 --- a/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp +++ b/src/duckdb/src/parser/transform/constraint/transform_constraint.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/constraints/list.hpp" #include "duckdb/parser/transformer.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { diff --git a/src/duckdb/src/parser/transform/expression/transform_array_access.cpp b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp index 447688c61..69c38eac0 100644 --- a/src/duckdb/src/parser/transform/expression/transform_array_access.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_array_access.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/operator_expression.hpp" @@ -7,7 +8,6 @@ namespace duckdb { unique_ptr Transformer::TransformArrayAccess(duckdb_libpgquery::PGAIndirection &indirection_node) { - // Transform the source expression. unique_ptr result; result = TransformExpression(indirection_node.arg); diff --git a/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp b/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp index 76078f5e2..1d5bfd219 100644 --- a/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_bool_expr.cpp @@ -33,10 +33,13 @@ unique_ptr Transformer::TransformBoolExpr(duckdb_libpgquery::P // convert COMPARE_IN to COMPARE_NOT_IN next->SetExpressionTypeUnsafe(ExpressionType::COMPARE_NOT_IN); result = std::move(next); - } else if (next->GetExpressionType() >= ExpressionType::COMPARE_EQUAL && - next->GetExpressionType() <= ExpressionType::COMPARE_GREATERTHANOREQUALTO) { + } else if ((next->GetExpressionType() >= ExpressionType::COMPARE_EQUAL && + next->GetExpressionType() <= ExpressionType::COMPARE_GREATERTHANOREQUALTO) || + next->GetExpressionType() == ExpressionType::COMPARE_DISTINCT_FROM || + next->GetExpressionType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { // NOT on a comparison: we can negate the comparison // e.g. 
NOT(x > y) is equivalent to x <= y + // NOT(x IS DISTINCT FROM y) is equivalent to x IS NOT DISTINCT FROM y next->SetExpressionTypeUnsafe(NegateComparisonExpression(next->GetExpressionType())); result = std::move(next); } else { diff --git a/src/duckdb/src/parser/transform/expression/transform_cast.cpp b/src/duckdb/src/parser/transform/expression/transform_cast.cpp index a4b1dde59..0412a3c96 100644 --- a/src/duckdb/src/parser/transform/expression/transform_cast.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_cast.cpp @@ -21,7 +21,9 @@ unique_ptr Transformer::TransformTypeCast(duckdb_libpgquery::P parameters.query_location = NumericCast(root.location); } auto blob_data = Blob::ToBlob(string(c->val.val.str), parameters); - return make_uniq(Value::BLOB_RAW(blob_data)); + auto result = make_uniq(Value::BLOB_RAW(blob_data)); + SetQueryLocation(*result, root.location); + return std::move(result); } } // transform the expression node diff --git a/src/duckdb/src/parser/transform/expression/transform_expression.cpp b/src/duckdb/src/parser/transform/expression/transform_expression.cpp index 8cbf53b70..73b42c9dd 100644 --- a/src/duckdb/src/parser/transform/expression/transform_expression.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_expression.cpp @@ -16,7 +16,6 @@ unique_ptr Transformer::TransformResTarget(duckdb_libpgquery:: } unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::PGNamedArgExpr &root) { - auto expr = TransformExpression(PGPointerCast(root.arg)); if (root.name) { expr->SetAlias(root.name); @@ -25,7 +24,6 @@ unique_ptr Transformer::TransformNamedArg(duckdb_libpgquery::P } unique_ptr Transformer::TransformExpression(duckdb_libpgquery::PGNode &node) { - auto stack_checker = StackCheck(); switch (node.type) { diff --git a/src/duckdb/src/parser/transform/expression/transform_function.cpp b/src/duckdb/src/parser/transform/expression/transform_function.cpp index b1993643c..574be3e2e 100644 --- a/src/duckdb/src/parser/transform/expression/transform_function.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_function.cpp @@ -38,7 +38,6 @@ void Transformer::TransformWindowDef(duckdb_libpgquery::PGWindowDef &window_spec static inline WindowBoundary TransformFrameOption(const int frameOptions, const WindowBoundary rows, const WindowBoundary range, const WindowBoundary groups) { - if (frameOptions & FRAMEOPTION_RANGE) { return range; } else if (frameOptions & FRAMEOPTION_GROUPS) { diff --git a/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp index 28dc623f3..4ad2a0de3 100644 --- a/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_multi_assign_reference.cpp @@ -4,7 +4,6 @@ namespace duckdb { unique_ptr Transformer::TransformMultiAssignRef(duckdb_libpgquery::PGMultiAssignRef &root) { - // Early-out, if the root is not a function call. 
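Note on the transform_bool_expr.cpp hunk above: NOT applied directly to a comparison is rewritten by flipping the comparison's expression type, and the change extends that rewrite to IS [NOT] DISTINCT FROM. The following is a minimal sketch of such a negation table; the enum values mirror the ones named in the diff but are redeclared here purely for illustration, and the helper name is hypothetical (the real code calls NegateComparisonExpression).

// Illustrative stand-in for the relevant comparison expression types.
enum class ExpressionType {
	COMPARE_EQUAL,
	COMPARE_NOTEQUAL,
	COMPARE_LESSTHAN,
	COMPARE_GREATERTHANOREQUALTO,
	COMPARE_GREATERTHAN,
	COMPARE_LESSTHANOREQUALTO,
	COMPARE_DISTINCT_FROM,
	COMPARE_NOT_DISTINCT_FROM
};

// NOT(a <op> b) becomes a <negated op> b; in particular
// NOT(a IS DISTINCT FROM b) becomes a IS NOT DISTINCT FROM b.
static ExpressionType NegateComparisonSketch(ExpressionType type) {
	switch (type) {
	case ExpressionType::COMPARE_EQUAL:
		return ExpressionType::COMPARE_NOTEQUAL;
	case ExpressionType::COMPARE_NOTEQUAL:
		return ExpressionType::COMPARE_EQUAL;
	case ExpressionType::COMPARE_LESSTHAN:
		return ExpressionType::COMPARE_GREATERTHANOREQUALTO;
	case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
		return ExpressionType::COMPARE_LESSTHAN;
	case ExpressionType::COMPARE_GREATERTHAN:
		return ExpressionType::COMPARE_LESSTHANOREQUALTO;
	case ExpressionType::COMPARE_LESSTHANOREQUALTO:
		return ExpressionType::COMPARE_GREATERTHAN;
	case ExpressionType::COMPARE_DISTINCT_FROM:
		return ExpressionType::COMPARE_NOT_DISTINCT_FROM;
	case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
		return ExpressionType::COMPARE_DISTINCT_FROM;
	}
	return type;
}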
if (root.source->type != duckdb_libpgquery::T_PGFuncCall) { return TransformExpression(root.source); diff --git a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp index bc8a9762d..986e46e25 100644 --- a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp @@ -24,7 +24,6 @@ unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::P subquery_expr->subquery = TransformSelectStmt(*root.subselect); SetQueryLocation(*subquery_expr, root.location); D_ASSERT(subquery_expr->subquery); - D_ASSERT(!subquery_expr->subquery->node->GetSelectList().empty()); switch (root.subLinkType) { case duckdb_libpgquery::PG_EXISTS_SUBLINK: { diff --git a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp index 2de5d8334..8280149cd 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp +++ b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/enums/set_operation_type.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/parser/statement/select_statement.hpp" @@ -23,9 +24,16 @@ unique_ptr CommonTableExpressionInfo::Copy() { CommonTableExpressionInfo::~CommonTableExpressionInfo() { } +CTEMaterialize CommonTableExpressionInfo::GetMaterializedForSerialization(Serializer &serializer) const { + if (serializer.ShouldSerialize(7)) { + return materialized; + } + return CTEMaterialize::CTE_MATERIALIZE_DEFAULT; +} + void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { for (auto &cte_entry : stored_cte_map) { - for (auto &entry : cte_entry->map) { + for (auto &entry : cte_entry.get().map) { auto found_entry = cte_map.map.find(entry.first); if (found_entry != cte_map.map.end()) { // entry already present - use top-most entry @@ -40,7 +48,7 @@ void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { } void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map) { - stored_cte_map.push_back(&cte_map); + stored_cte_map.push_back(cte_map); // TODO: might need to update in case of future lawsuit D_ASSERT(de_with_clause.ctes); diff --git a/src/duckdb/src/parser/transform/helpers/transform_sample.cpp b/src/duckdb/src/parser/transform/helpers/transform_sample.cpp index bd1cc75a2..c563957bc 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_sample.cpp +++ b/src/duckdb/src/parser/transform/helpers/transform_sample.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/string_util.hpp" namespace duckdb { +constexpr idx_t SampleOptions::MAX_SAMPLE_ROWS; static SampleMethod GetSampleMethod(const string &method) { auto lmethod = StringUtil::Lower(method); @@ -44,8 +45,9 @@ unique_ptr Transformer::TransformSampleOptions(optional_ptr(); - if (rows < 0) { - throw ParserException("Sample rows %lld out of range, must be bigger than or equal to 0", rows); + if (rows < 0 || sample_value.GetValue() > SampleOptions::MAX_SAMPLE_ROWS) { + throw ParserException("Sample rows %lld out of range, must be between 0 and %lld", rows, + SampleOptions::MAX_SAMPLE_ROWS); } result->sample_size = Value::BIGINT(rows); result->method = SampleMethod::RESERVOIR_SAMPLE; diff --git 
a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp index c071af00b..86ce2a73c 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp +++ b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp @@ -17,7 +17,6 @@ struct SizeModifiers { }; static SizeModifiers GetSizeModifiers(duckdb_libpgquery::PGTypeName &type_name, LogicalTypeId base_type) { - SizeModifiers result; if (base_type == LogicalTypeId::DECIMAL) { @@ -97,6 +96,11 @@ LogicalType Transformer::TransformTypeNameInternal(duckdb_libpgquery::PGTypeName // transform it to the SQL type LogicalTypeId base_type = TransformStringToLogicalTypeId(name); + if (base_type == LogicalTypeId::GEOMETRY) { + // Always return a type with GeoTypeInfo + return LogicalType::GEOMETRY(); + } + if (base_type == LogicalTypeId::LIST) { throw ParserException("LIST is not valid as a stand-alone type"); } diff --git a/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp index 7fb15a90d..8e676f3c8 100644 --- a/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_alter_table.cpp @@ -3,6 +3,7 @@ #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/statement/alter_statement.hpp" #include "duckdb/parser/transformer.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { @@ -19,7 +20,6 @@ vector Transformer::TransformNameList(duckdb_libpgquery::PGList &list) { } unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlterTableStmt &stmt) { - D_ASSERT(stmt.relation); if (stmt.cmds->length != 1) { throw ParserException("Only one ALTER command per statement is supported"); @@ -30,7 +30,6 @@ unique_ptr Transformer::TransformAlter(duckdb_libpgquery::PGAlte // Check the ALTER type. 
for (auto c = stmt.cmds->head; c != nullptr; c = c->next) { - auto command = PGPointerCast(c->data.ptr_value); AlterEntryData data(qualified_name.catalog, qualified_name.schema, qualified_name.name, TransformOnEntryNotFound(stmt.missing_ok)); diff --git a/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp b/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp index e8a9d83ec..16ab1b596 100644 --- a/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_create_table_as.cpp @@ -21,6 +21,9 @@ unique_ptr Transformer::TransformCreateTableAs(duckdb_libpgquer auto result = make_uniq(); auto info = make_uniq(); auto qname = TransformQualifiedName(*stmt.into->rel); + if (qname.name.empty()) { + throw ParserException("Empty table name not supported"); + } auto query = TransformSelectStmt(*stmt.query, false); // push a LIMIT 0 if 'WITH NO DATA' is specified diff --git a/src/duckdb/src/parser/transform/statement/transform_explain.cpp b/src/duckdb/src/parser/transform/statement/transform_explain.cpp index 510395a99..969a8827b 100644 --- a/src/duckdb/src/parser/transform/statement/transform_explain.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_explain.cpp @@ -11,7 +11,8 @@ ExplainFormat ParseFormat(const Value &val) { auto format_val = val.GetValue(); case_insensitive_map_t format_mapping { {"default", ExplainFormat::DEFAULT}, {"text", ExplainFormat::TEXT}, {"json", ExplainFormat::JSON}, - {"html", ExplainFormat::HTML}, {"graphviz", ExplainFormat::GRAPHVIZ}, {"yaml", ExplainFormat::YAML}}; + {"html", ExplainFormat::HTML}, {"graphviz", ExplainFormat::GRAPHVIZ}, {"yaml", ExplainFormat::YAML}, + {"mermaid", ExplainFormat::MERMAID}}; auto it = format_mapping.find(format_val); if (it != format_mapping.end()) { return it->second; diff --git a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp index 07dfb420d..4572a3a36 100644 --- a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp @@ -95,7 +95,7 @@ unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr(); - select->node = TransformMaterializedCTE(std::move(subselect)); + select->node = std::move(subselect); info->query = std::move(select); info->type = LogicalType::INVALID; diff --git a/src/duckdb/src/parser/transform/statement/transform_select.cpp b/src/duckdb/src/parser/transform/statement/transform_select.cpp index 2e5135ef6..16cd1a490 100644 --- a/src/duckdb/src/parser/transform/statement/transform_select.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_select.cpp @@ -26,13 +26,10 @@ unique_ptr Transformer::TransformSelectNodeInternal(duckdb_libpgquery throw ParserException("SELECT locking clause is not supported!"); } } - unique_ptr stmt = nullptr; if (select.pivot) { - stmt = TransformPivotStatement(select); - } else { - stmt = TransformSelectInternal(select); + return TransformPivotStatement(select); } - return TransformMaterializedCTE(std::move(stmt)); + return TransformSelectInternal(select); } unique_ptr Transformer::TransformSelectStmt(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { diff --git a/src/duckdb/src/parser/transform/statement/transform_upsert.cpp b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp index aa0130f3c..8d5fdaf35 100644 --- a/src/duckdb/src/parser/transform/statement/transform_upsert.cpp 
+++ b/src/duckdb/src/parser/transform/statement/transform_upsert.cpp @@ -67,7 +67,6 @@ unique_ptr Transformer::DummyOnConflictClause(duckdb_libpgquery: unique_ptr Transformer::TransformOnConflictClause(duckdb_libpgquery::PGOnConflictClause *node, const string &) { - auto stmt = PGPointerCast(node); D_ASSERT(stmt); diff --git a/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp b/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp index 14205959a..3157238d0 100644 --- a/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp +++ b/src/duckdb/src/parser/transform/tableref/transform_base_tableref.cpp @@ -43,7 +43,7 @@ QualifiedName Transformer::TransformQualifiedName(duckdb_libpgquery::PGRangeVar if (root.relname) { qname.name = root.relname; } else { - qname.name = string(); + throw ParserException("Empty table name not supported"); } return qname; } diff --git a/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp b/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp index d042d881d..dcb2dc036 100644 --- a/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp +++ b/src/duckdb/src/parser/transform/tableref/transform_pivot.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/parser/tableref/pivotref.hpp" #include "duckdb/parser/transformer.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" diff --git a/src/duckdb/src/parser/transformer.cpp b/src/duckdb/src/parser/transformer.cpp index 4ab39fca7..f3f058899 100644 --- a/src/duckdb/src/parser/transformer.cpp +++ b/src/duckdb/src/parser/transformer.cpp @@ -2,9 +2,7 @@ #include "duckdb/parser/expression/list.hpp" #include "duckdb/parser/statement/list.hpp" -#include "duckdb/parser/tableref/emptytableref.hpp" #include "duckdb/parser/query_node/select_node.hpp" -#include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/parser_options.hpp" namespace duckdb { @@ -232,31 +230,6 @@ unique_ptr Transformer::TransformStatementInternal(duckdb_libpgque } } -unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - - for (auto &cte : root->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = TransformMaterializedCTE(cte_entry->query->node->Copy()); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return root; -} - void Transformer::SetQueryLocation(ParsedExpression &expr, int query_location) { if (query_location < 0) { return; diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp index b6e5df81f..cc1b3d25e 100644 --- a/src/duckdb/src/planner/bind_context.cpp +++ b/src/duckdb/src/planner/bind_context.cpp @@ -38,16 +38,17 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) optional_ptr result; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; - auto is_using_binding = GetUsingBinding(column_name, binding.alias); + auto is_using_binding = GetUsingBinding(column_name, binding.GetBindingAlias()); if (is_using_binding) { 
continue; } if (binding.HasMatchingBinding(column_name)) { if (result || is_using_binding) { - throw BinderException("Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " - "or \"%s.%s\")", - column_name, MinimumUniqueAlias(result->alias, binding.alias), column_name, - MinimumUniqueAlias(binding.alias, result->alias), column_name); + throw BinderException( + "Ambiguous reference to column name \"%s\" (use: \"%s.%s\" " + "or \"%s.%s\")", + column_name, MinimumUniqueAlias(result->GetBindingAlias(), binding.GetBindingAlias()), column_name, + MinimumUniqueAlias(binding.GetBindingAlias(), result->GetBindingAlias()), column_name); } result = &binding; } @@ -58,8 +59,8 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name) vector BindContext::GetSimilarBindings(const string &column_name) { vector> scores; for (auto &binding_ptr : bindings_list) { - auto binding = *binding_ptr; - for (auto &name : binding.names) { + auto &binding = *binding_ptr; + for (auto &name : binding.GetColumnNames()) { double distance = StringUtil::SimilarityRating(name, column_name); // check if we need to qualify the column auto matching_bindings = GetMatchingBindings(name); @@ -77,10 +78,6 @@ void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set using_columns[column_name].insert(set); } -void BindContext::AddUsingBindingSet(unique_ptr set) { - using_column_sets.push_back(std::move(set)); -} - optional_ptr BindContext::GetUsingBinding(const string &column_name) { auto entry = using_columns.find(column_name); if (entry == using_columns.end()) { @@ -161,7 +158,7 @@ string BindContext::GetActualColumnName(Binding &binding, const string &column_n throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding.GetAlias(), column_name); } // LCOV_EXCL_STOP - return binding.names[binding_index]; + return binding.GetColumnNames()[binding_index]; } string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const string &column_name) { @@ -204,7 +201,7 @@ unique_ptr BindContext::CreateColumnReference(const string &ta } static bool ColumnIsGenerated(Binding &binding, column_t index) { - if (binding.binding_type != BindingType::TABLE) { + if (binding.GetBindingType() != BindingType::TABLE) { return false; } auto &table_binding = binding.Cast(); @@ -243,10 +240,12 @@ unique_ptr BindContext::CreateColumnReference(const string &ca auto column_index = binding->GetBindingIndex(column_name); if (bind_type == ColumnBindType::EXPAND_GENERATED_COLUMNS && ColumnIsGenerated(*binding, column_index)) { return ExpandGeneratedColumn(binding->Cast(), column_name); - } else if (column_index < binding->names.size() && binding->names[column_index] != column_name) { + } + auto &column_names = binding->GetColumnNames(); + if (column_index < column_names.size() && column_names[column_index] != column_name) { // because of case insensitivity in the binder we rename the column to the original name // as it appears in the binding itself - result->SetAlias(binding->names[column_index]); + result->SetAlias(column_names[column_index]); } return std::move(result); } @@ -257,14 +256,6 @@ unique_ptr BindContext::CreateColumnReference(const string &sc return CreateColumnReference(catalog_name, schema_name, table_name, column_name, bind_type); } -optional_ptr BindContext::GetCTEBinding(const string &ctename) { - auto match = cte_bindings.find(ctename); - if (match == cte_bindings.end()) { - return nullptr; - } - return match->second.get(); -} - string 
GetCandidateAlias(const BindingAlias &main_alias, const BindingAlias &new_alias) { string candidate; if (!main_alias.GetCatalog().empty() && !new_alias.GetCatalog().empty()) { @@ -283,7 +274,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E } vector> matching_bindings; for (auto &binding : bindings_list) { - if (binding->alias.Matches(alias)) { + if (binding->GetBindingAlias().Matches(alias)) { matching_bindings.push_back(*binding); } } @@ -291,7 +282,7 @@ vector> BindContext::GetBindings(const BindingAlias &alias, E // alias not found in this BindContext vector candidates; for (auto &binding : bindings_list) { - candidates.push_back(GetCandidateAlias(alias, binding->alias)); + candidates.push_back(GetCandidateAlias(alias, binding->GetBindingAlias())); } auto main_alias = GetCandidateAlias(alias, alias); string candidate_str = @@ -315,14 +306,14 @@ string BindContext::AmbiguityException(const BindingAlias &alias, const vector handled_using_columns; for (auto &entry : bindings_list) { auto &binding = *entry; - for (auto &column_name : binding.names) { - QualifiedColumnName qualified_column(binding.alias, column_name); + auto &column_names = binding.GetColumnNames(); + auto &binding_alias = binding.GetBindingAlias(); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_column(binding_alias, column_name); if (CheckExclusionList(expr, qualified_column, exclusion_info)) { continue; } // check if this column is a USING column - auto using_binding_ptr = GetUsingBinding(column_name, binding.alias); + auto using_binding_ptr = GetUsingBinding(column_name, binding_alias); if (using_binding_ptr) { auto &using_binding = *using_binding_ptr; // it is! @@ -530,7 +524,7 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, continue; } auto new_expr = - CreateColumnReference(binding.alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_column, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -548,17 +542,20 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, } is_struct_ref = true; } + auto &binding_alias = binding->GetBindingAlias(); + auto &column_names = binding->GetColumnNames(); + auto &column_types = binding->GetColumnTypes(); if (is_struct_ref) { auto col_idx = binding->GetBindingIndex(expr.relation_name); - auto col_type = binding->types[col_idx]; + auto col_type = column_types[col_idx]; if (col_type.id() != LogicalTypeId::STRUCT) { throw BinderException(StringUtil::Format( "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); } auto &struct_children = StructType::GetChildTypes(col_type); vector column_names(3); - column_names[0] = binding->alias.GetAlias(); + column_names[0] = binding->GetAlias(); column_names[1] = expr.relation_name; for (auto &child : struct_children) { QualifiedColumnName qualified_name(child.first); @@ -571,13 +568,13 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, new_select_list.push_back(std::move(new_expr)); } } else { - for (auto &column_name : binding->names) { - QualifiedColumnName qualified_name(binding->alias, column_name); + for (auto &column_name : column_names) { + QualifiedColumnName qualified_name(binding_alias, column_name); if (CheckExclusionList(expr, qualified_name, exclusion_info)) { continue; } auto new_expr = - CreateColumnReference(binding->alias, 
column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + CreateColumnReference(binding_alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); HandleRename(expr, qualified_name, *new_expr); new_select_list.push_back(std::move(new_expr)); } @@ -613,10 +610,12 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { for (auto &binding_entry : bindings_list) { auto &binding = *binding_entry; - D_ASSERT(binding.names.size() == binding.types.size()); - for (idx_t i = 0; i < binding.names.size(); i++) { - result_names.push_back(binding.names[i]); - result_types.push_back(binding.types[i]); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); + for (idx_t i = 0; i < column_names.size(); i++) { + result_names.push_back(column_names[i]); + result_types.push_back(column_types[i]); } } } @@ -686,7 +685,7 @@ vector BindContext::AliasColumnNames(const string &table_name, const vec return result; } -void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } @@ -696,13 +695,13 @@ void BindContext::AddEntryBinding(idx_t index, const string &alias, const vector AddBinding(make_uniq(alias, types, names, index, entry)); } -void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, +void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddEntryBinding(index, alias, names, subquery.types, view.Cast()); } -void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } @@ -712,33 +711,28 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect AddBinding(make_uniq(BindingType::BASE, BindingAlias(alias), types, names, index)); } -void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, - const vector &types, bool using_key) { - auto binding = make_shared_ptr(BindingType::BASE, BindingAlias(alias), types, names, index); - - if (cte_bindings.find(alias) != cte_bindings.end()) { - throw BinderException("Duplicate CTE binding \"%s\" in query!", alias); +void BindContext::AddCTEBinding(unique_ptr binding) { + for (auto &cte_binding : cte_bindings) { + if (cte_binding->GetBindingAlias() == binding->GetBindingAlias()) { + throw BinderException("Duplicate CTE binding \"%s\" in query!", binding->GetBindingAlias().ToString()); + } } - cte_bindings[alias] = std::move(binding); - cte_references[alias] = make_shared_ptr(0); + cte_bindings.push_back(std::move(binding)); +} - if (using_key) { - auto recurring_alias = "recurring." 
+ alias; - cte_bindings[recurring_alias] = - make_shared_ptr(BindingType::BASE, BindingAlias(recurring_alias), types, names, index); - cte_references[recurring_alias] = make_shared_ptr(0); - } +void BindContext::AddCTEBinding(idx_t index, BindingAlias alias_p, const vector &names, + const vector &types, CTEType cte_type) { + auto binding = make_uniq(std::move(alias_p), types, names, index, cte_type); + AddCTEBinding(std::move(binding)); } -void BindContext::RemoveCTEBinding(const std::string &alias) { - auto it = cte_bindings.find(alias); - if (it != cte_bindings.end()) { - cte_bindings.erase(it); - } - auto it2 = cte_references.find(alias); - if (it2 != cte_references.end()) { - cte_references.erase(it2); +optional_ptr BindContext::GetCTEBinding(const BindingAlias &ctename) { + for (auto &binding : cte_bindings) { + if (binding->GetBindingAlias().Matches(ctename)) { + return binding.get(); + } } + return nullptr; } void BindContext::AddContext(BindContext other) { @@ -755,7 +749,7 @@ void BindContext::AddContext(BindContext other) { vector BindContext::GetBindingAliases() { vector result; for (auto &binding : bindings_list) { - result.push_back(BindingAlias(binding->alias)); + result.push_back(binding->GetBindingAlias()); } return result; } @@ -782,7 +776,7 @@ void BindContext::RemoveContext(const vector &aliases) { // remove the binding from the list of bindings auto it = std::remove_if(bindings_list.begin(), bindings_list.end(), - [&](unique_ptr &x) { return x->alias == alias; }); + [&](unique_ptr &x) { return x->GetBindingAlias() == alias; }); bindings_list.erase(it, bindings_list.end()); } } diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 2ba52b64f..fe1b59cf3 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -28,10 +28,6 @@ namespace duckdb { -Binder &Binder::GetRootBinder() { - return root_binder; -} - idx_t Binder::GetBinderDepth() const { return depth; } @@ -50,9 +46,11 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent_p, BinderType binder_type) - : context(context), bind_context(*this), parent(std::move(parent_p)), bound_tables(0), binder_type(binder_type), - entry_retriever(context), root_binder(parent ? parent->GetRootBinder() : *this), - depth(parent ? parent->GetBinderDepth() : 1) { + : context(context), bind_context(*this), parent(std::move(parent_p)), binder_type(binder_type), + global_binder_state(parent ? parent->global_binder_state : make_shared_ptr()), + query_binder_state(parent && binder_type == BinderType::REGULAR_BINDER ? parent->query_binder_state + : make_shared_ptr()), + entry_retriever(context), depth(parent ? parent->GetBinderDepth() : 1) { IncreaseDepth(); if (parent) { entry_retriever.Inherit(parent->entry_retriever); @@ -60,85 +58,22 @@ Binder::Binder(ClientContext &context, shared_ptr parent_p, BinderType b // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. macro_binding = parent->macro_binding; lambda_bindings = parent->lambda_bindings; - - if (binder_type == BinderType::REGULAR_BINDER) { - // We have to inherit CTE bindings from the parent bind_context, if there is a parent. 
- bind_context.SetCTEBindings(parent->bind_context.GetCTEBindings()); - bind_context.cte_references = parent->bind_context.cte_references; - parameters = parent->parameters; - } - } -} - -unique_ptr Binder::BindMaterializedCTE(CommonTableExpressionMap &cte_map) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - for (auto &cte : cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - if (materialized_ctes.empty()) { - return nullptr; - } - - unique_ptr cte_root = nullptr; - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = cte_map.Copy(); - if (cte_root) { - node_result->child = std::move(cte_root); - } else { - node_result->child = nullptr; - } - cte_root = std::move(node_result); - materialized_ctes.pop_back(); } - - AddCTEMap(cte_map); - auto bound_cte = BindCTE(cte_root->Cast()); - - return bound_cte; } template BoundStatement Binder::BindWithCTE(T &statement) { - BoundStatement bound_statement; - auto bound_cte = BindMaterializedCTE(statement.template Cast().cte_map); - if (bound_cte) { - reference tail_ref = *bound_cte; - - while (tail_ref.get().child && tail_ref.get().child->type == QueryNodeType::CTE_NODE) { - tail_ref = tail_ref.get().child->Cast(); - } - - auto &tail = tail_ref.get(); - bound_statement = tail.child_binder->Bind(statement.template Cast()); - - tail.types = bound_statement.types; - tail.names = bound_statement.names; - - for (auto &c : tail.query_binder->correlated_columns) { - tail.child_binder->AddCorrelatedColumn(c); - } - MoveCorrelatedExpressions(*tail.child_binder); - - auto plan = std::move(bound_statement.plan); - bound_statement.plan = CreatePlan(*bound_cte, std::move(plan)); - } else { - bound_statement = Bind(statement.template Cast()); + auto &cte_map = statement.cte_map; + if (cte_map.map.empty()) { + return Bind(statement); } - return bound_statement; + + auto stmt_node = make_uniq(statement); + stmt_node->cte_map = cte_map.Copy(); + return Bind(*stmt_node); } BoundStatement Binder::Bind(SQLStatement &statement) { - root_statement = &statement; switch (statement.type) { case StatementType::SELECT_STATEMENT: return Bind(statement.Cast()); @@ -198,64 +133,12 @@ BoundStatement Binder::Bind(SQLStatement &statement) { } // LCOV_EXCL_STOP } -void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { - for (auto &cte_it : cte_map.map) { - AddCTE(cte_it.first); - } -} - -unique_ptr Binder::BindNode(QueryNode &node) { - // first we visit the set of CTEs and add them to the bind context - AddCTEMap(node.cte_map); - // now we bind the node - unique_ptr result; - switch (node.type) { - case QueryNodeType::SELECT_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::RECURSIVE_CTE_NODE: - result = BindNode(node.Cast()); - break; - case QueryNodeType::CTE_NODE: - result = BindNode(node.Cast()); - break; - default: - D_ASSERT(node.type == QueryNodeType::SET_OPERATION_NODE); - result = BindNode(node.Cast()); - break; - } - return result; -} - BoundStatement Binder::Bind(QueryNode &node) { - BoundStatement result; - auto bound_node = BindNode(node); - - result.names = bound_node->names; - result.types = bound_node->types; - - // and plan it - result.plan = CreatePlan(*bound_node); - 
return result; + return BindNode(node); } -unique_ptr Binder::CreatePlan(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::CTE_NODE: - return CreatePlan(node.Cast()); - default: - throw InternalException("Unsupported bound query node type"); - } -} - -unique_ptr Binder::Bind(TableRef &ref) { - unique_ptr result; +BoundStatement Binder::Bind(TableRef &ref) { + BoundStatement result; switch (ref.type) { case TableReferenceType::BASE_TABLE: result = Bind(ref.Cast()); @@ -295,80 +178,33 @@ unique_ptr Binder::Bind(TableRef &ref) { default: throw InternalException("Unknown table ref type (%s)", EnumUtil::ToString(ref.type)); } - result->sample = std::move(ref.sample); - return result; -} - -unique_ptr Binder::CreatePlan(BoundTableRef &ref) { - unique_ptr root; - switch (ref.type) { - case TableReferenceType::BASE_TABLE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::SUBQUERY: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::JOIN: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::TABLE_FUNCTION: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EMPTY_FROM: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EXPRESSION_LIST: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::COLUMN_DATA: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::CTE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::PIVOT: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::DELIM_GET: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::INVALID: - default: - throw InternalException("Unsupported bound table ref type (%s)", EnumUtil::ToString(ref.type)); - } - // plan the sample clause if (ref.sample) { - root = make_uniq(std::move(ref.sample), std::move(root)); - } - return root; -} - -void Binder::AddCTE(const string &name) { - D_ASSERT(!name.empty()); - CTE_bindings.insert(name); -} - -vector> Binder::FindCTE(const string &name, bool skip) { - auto entry = bind_context.GetCTEBinding(name); - vector> ctes; - if (entry) { - ctes.push_back(*entry.get()); - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - auto parent_ctes = parent->FindCTE(name, name == alias); - ctes.insert(ctes.end(), parent_ctes.begin(), parent_ctes.end()); + result.plan = make_uniq(std::move(ref.sample), std::move(result.plan)); } - return ctes; + return result; } -bool Binder::CTEExists(const string &name) { - if (CTE_bindings.find(name) != CTE_bindings.end()) { - return true; - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - return parent->CTEExists(name); +optional_ptr Binder::GetCTEBinding(const BindingAlias &name) { + reference current_binder(*this); + optional_ptr result; + while (true) { + auto ¤t = current_binder.get(); + auto entry = current.bind_context.GetCTEBinding(name); + if (entry) { + // we only directly return the CTE if it can be referenced + // if it cannot be referenced (circular reference) we keep going up the stack + // to look for a CTE that can be referenced + if (entry->CanBeReferenced()) { + return entry; + } + result = entry; + } + if (!current.parent || current.binder_type != BinderType::REGULAR_BINDER) { + break; + } + current_binder = *current.parent; } - return false; + 
return result; } void Binder::AddBoundView(ViewCatalogEntry &view) { @@ -384,13 +220,19 @@ void Binder::AddBoundView(ViewCatalogEntry &view) { } idx_t Binder::GenerateTableIndex() { - auto &root_binder = GetRootBinder(); - return root_binder.bound_tables++; + return global_binder_state->bound_tables++; } StatementProperties &Binder::GetStatementProperties() { - auto &root_binder = GetRootBinder(); - return root_binder.prop; + return global_binder_state->prop; +} + +optional_ptr Binder::GetParameters() { + return global_binder_state->parameters; +} + +void Binder::SetParameters(BoundParameterMap ¶meters) { + global_binder_state->parameters = parameters; } void Binder::PushExpressionBinder(ExpressionBinder &binder) { @@ -416,17 +258,11 @@ bool Binder::HasActiveBinder() { } vector> &Binder::GetActiveBinders() { - reference root = *this; - while (root.get().parent && root.get().binder_type == BinderType::REGULAR_BINDER) { - root = *root.get().parent; - } - auto &root_binder = root.get(); - return root_binder.active_binders; + return query_binder_state->active_binders; } void Binder::AddUsingBindingSet(unique_ptr set) { - auto &root_binder = GetRootBinder(); - root_binder.bind_context.AddUsingBindingSet(std::move(set)); + global_binder_state->using_column_sets.push_back(std::move(set)); } void Binder::MoveCorrelatedExpressions(Binder &other) { @@ -434,7 +270,7 @@ void Binder::MoveCorrelatedExpressions(Binder &other) { other.correlated_columns.clear(); } -void Binder::MergeCorrelatedColumns(vector &other) { +void Binder::MergeCorrelatedColumns(CorrelatedColumns &other) { for (idx_t i = 0; i < other.size(); i++) { AddCorrelatedColumn(other[i]); } @@ -443,7 +279,7 @@ void Binder::MergeCorrelatedColumns(vector &other) { void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { // we only add correlated columns to the list if they are not already there if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(info); + correlated_columns.AddColumn(info); } } @@ -463,7 +299,6 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con const string &table_name, const string &column_name, ErrorData &error) { optional_ptr binding; - D_ASSERT(!lambda_bindings); if (macro_binding && table_name == macro_binding->GetAlias()) { binding = optional_ptr(macro_binding.get()); } else { @@ -474,13 +309,11 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con } void Binder::SetBindingMode(BindingMode mode) { - auto &root_binder = GetRootBinder(); - root_binder.mode = mode; + global_binder_state->mode = mode; } BindingMode Binder::GetBindingMode() { - auto &root_binder = GetRootBinder(); - return root_binder.mode; + return global_binder_state->mode; } void Binder::SetCanContainNulls(bool can_contain_nulls_p) { @@ -493,30 +326,26 @@ void Binder::SetAlwaysRequireRebind() { } void Binder::AddTableName(string table_name) { - auto &root_binder = GetRootBinder(); - root_binder.table_names.insert(std::move(table_name)); + global_binder_state->table_names.insert(std::move(table_name)); } void Binder::AddReplacementScan(const string &table_name, unique_ptr replacement) { - auto &root_binder = GetRootBinder(); - auto it = root_binder.replacement_scans.find(table_name); + auto it = global_binder_state->replacement_scans.find(table_name); replacement->column_name_alias.clear(); replacement->alias.clear(); - if (it == root_binder.replacement_scans.end()) { - root_binder.replacement_scans[table_name] = 
std::move(replacement); + if (it == global_binder_state->replacement_scans.end()) { + global_binder_state->replacement_scans[table_name] = std::move(replacement); } else { // A replacement scan by this name was previously registered, we can just use it } } const unordered_set &Binder::GetTableNames() { - auto &root_binder = GetRootBinder(); - return root_binder.table_names; + return global_binder_state->table_names; } case_insensitive_map_t> &Binder::GetReplacementScans() { - auto &root_binder = GetRootBinder(); - return root_binder.replacement_scans; + return global_binder_state->replacement_scans; } // FIXME: this is extremely naive @@ -537,7 +366,6 @@ void VerifyNotExcluded(const ParsedExpression &root_expr) { BoundStatement Binder::BindReturning(vector> returning_list, TableCatalogEntry &table, const string &alias, idx_t update_table_index, unique_ptr child_operator, virtual_column_map_t virtual_columns) { - vector types; vector names; @@ -582,7 +410,7 @@ BoundStatement Binder::BindReturning(vector> return // returned, it should be guaranteed that the row has been inserted. // see https://github.com/duckdb/duckdb/issues/8310 auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::QUERY_RESULT; return result; } diff --git a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp index 09c92dd48..fe7e34e53 100644 --- a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp @@ -32,7 +32,6 @@ BindResult ExpressionBinder::BindExpression(BetweenExpression &expr, idx_t depth LogicalType input_type; if (!BoundComparisonExpression::TryBindComparison(context, input_sql_type, lower_sql_type, input_type, expr.GetExpressionType())) { - throw BinderException(expr, "Cannot mix values of type %s and %s in BETWEEN clause - an explicit cast is required", input_sql_type.ToString(), lower_sql_type.ToString()); diff --git a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp index 886a1ff42..ca6824599 100644 --- a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp @@ -94,12 +94,12 @@ unique_ptr ExpressionBinder::QualifyColumnName(const string &c // bind as a macro column if (is_macro_column) { - return binder.bind_context.CreateColumnReference(binder.macro_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(binder.macro_binding->GetBindingAlias(), column_name); } // bind as a regular column if (table_binding) { - return binder.bind_context.CreateColumnReference(table_binding->alias, column_name); + return binder.bind_context.CreateColumnReference(table_binding->GetBindingAlias(), column_name); } // it's not, find candidates and error @@ -111,7 +111,6 @@ unique_ptr ExpressionBinder::QualifyColumnName(const string &c void ExpressionBinder::QualifyColumnNames(unique_ptr &expr, vector> &lambda_params, const bool within_function_expression) { - bool next_within_function_expression = false; switch (expr->GetExpressionType()) { case ExpressionType::COLUMN_REF: { @@ -177,7 +176,6 @@ void ExpressionBinder::QualifyColumnNames(unique_ptr &expr, void 
ExpressionBinder::QualifyColumnNamesInLambda(FunctionExpression &function, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { // not a lambda expression @@ -228,7 +226,6 @@ void ExpressionBinder::QualifyColumnNames(ExpressionBinder &expression_binder, u unique_ptr ExpressionBinder::CreateStructExtract(unique_ptr base, const string &field_name) { - vector> children; children.push_back(std::move(base)); children.push_back(make_uniq_base(Value(field_name))); @@ -276,11 +273,12 @@ unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpress } // We found the table, now create the struct_pack expression + auto &column_names = binding->GetColumnNames(); vector> child_expressions; - child_expressions.reserve(binding->names.size()); - for (const auto &column_name : binding->names) { + child_expressions.reserve(column_names.size()); + for (const auto &column_name : column_names) { child_expressions.push_back(binder.bind_context.CreateColumnReference( - binding->alias, column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); + binding->GetBindingAlias(), column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); } return make_uniq("struct_pack", std::move(child_expressions)); } @@ -312,7 +310,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.schema.table.column" struct_extract_start = 4; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[3]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[3]); } } ErrorData catalog_table_error; @@ -321,7 +319,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte if (binding) { // part1 is a catalog - the column reference is "catalog.table.column" struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData schema_table_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2], @@ -330,7 +328,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // part1 is a schema - the column reference is "schema.table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[2]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); } ErrorData table_column_error; binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], table_column_error); @@ -339,7 +337,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte // the column reference is "table.column" // any additional fields are turned into struct_extract calls struct_extract_start = 2; - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.column_names[1]); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[1]); } // part1 could be a column ErrorData unused_error; @@ -360,7 +358,7 @@ unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInte optional_idx schema_pos; optional_idx table_pos; for (const auto &binding_entry : binder.bind_context.GetBindingsList()) { - auto &alias = binding_entry->alias; + auto &alias = 
binding_entry->GetBindingAlias(); string catalog = alias.GetCatalog(); string schema = alias.GetSchema(); string table = alias.GetAlias(); @@ -483,7 +481,7 @@ unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpres auto binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], error); if (binding) { // it is! return the column reference directly - return binder.bind_context.CreateColumnReference(binding->alias, col_ref.GetColumnName()); + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.GetColumnName()); } // otherwise check if we can turn this into a struct extract @@ -506,7 +504,8 @@ BindResult ExpressionBinder::BindExpression(LambdaRefExpression &lambda_ref, idx return (*lambda_bindings)[lambda_ref.lambda_idx].Bind(lambda_ref, depth); } -BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_t depth, bool root_expression) { +BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_t depth, bool root_expression, + unique_ptr &expr_ptr) { if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES || binder.GetBindingMode() == BindingMode::EXTRACT_QUALIFIED_NAMES) { return BindResult(make_uniq(Value(LogicalType::SQLNULL))); @@ -515,21 +514,17 @@ BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_ ErrorData error; auto expr = QualifyColumnName(col_ref_p, error); if (!expr) { - if (!col_ref_p.IsQualified()) { - // column was not found - // first try to bind it as an alias + // column wasn't found + if (ExpressionBinder::IsPotentialAlias(col_ref_p)) { BindResult alias_result; - auto found_alias = TryBindAlias(col_ref_p, root_expression, alias_result); + auto found_alias = TryResolveAliasReference(col_ref_p, depth, root_expression, alias_result, expr_ptr); if (found_alias) { return alias_result; } - found_alias = QualifyColumnAlias(col_ref_p); - if (!found_alias) { - // column was not found - check if it is a SQL value function - auto value_function = GetSQLValueFunction(col_ref_p.GetColumnName()); - if (value_function) { - return BindExpression(value_function, depth); - } + + auto value_function = GetSQLValueFunction(col_ref_p.GetColumnName()); + if (value_function) { + return BindExpression(value_function, depth); } } error.AddQueryLocation(col_ref_p); @@ -577,9 +572,4 @@ BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_ return result; } -bool ExpressionBinder::QualifyColumnAlias(const ColumnRefExpression &col_ref) { - // only the BaseSelectBinder will have a valid column alias map, - // otherwise we return false - return false; -} } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp index e0d775db1..92a5d0757 100644 --- a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp @@ -27,6 +27,7 @@ BindResult ExpressionBinder::TryBindLambdaOrJson(FunctionExpression &function, i auto setting = config.lambda_syntax; bool invalid_syntax = setting == LambdaSyntax::DISABLE_SINGLE_ARROW && syntax_type == LambdaSyntaxType::SINGLE_ARROW; + bool warn_deprecated_syntax = setting == LambdaSyntax::DEFAULT && syntax_type == LambdaSyntaxType::SINGLE_ARROW; const string msg = "Deprecated lambda arrow (->) detected. Please transition to the new lambda syntax, " "i.e.., lambda x, i: x + i, before DuckDB's next release. 
\n" "Use SET lambda_syntax='ENABLE_SINGLE_ARROW' to revert to the deprecated behavior. \n" @@ -49,11 +50,18 @@ BindResult ExpressionBinder::TryBindLambdaOrJson(FunctionExpression &function, i if (!lambda_bind_result.HasError()) { if (!invalid_syntax) { + if (warn_deprecated_syntax) { + DUCKDB_LOG_WARNING(context, msg); + } return lambda_bind_result; } return BindResult(msg); } if (StringUtil::Contains(lambda_bind_result.error.RawMessage(), "Deprecated lambda arrow (->) detected.")) { + if (warn_deprecated_syntax) { + DUCKDB_LOG_WARNING(context, msg); + } + return lambda_bind_result; } @@ -107,7 +115,7 @@ optional_ptr ExpressionBinder::BindAndQualifyFunction(FunctionExpr auto new_colref = QualifyColumnName(*colref, error); if (error.HasError()) { // could not find the column - try to qualify the alias - if (!QualifyColumnAlias(*colref)) { + if (!DoesColumnAliasExist(*colref)) { if (!allow_throw) { return func; } @@ -195,7 +203,7 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu } if (result->GetExpressionType() == ExpressionType::BOUND_FUNCTION) { auto &bound_function = result->Cast(); - if (bound_function.function.stability == FunctionStability::CONSISTENT_WITHIN_QUERY) { + if (bound_function.function.GetStability() == FunctionStability::CONSISTENT_WITHIN_QUERY) { binder.SetAlwaysRequireRebind(); } } @@ -204,10 +212,9 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, ScalarFunctionCatalogEntry &func, idx_t depth) { - // get the callback function for the lambda parameter types auto &scalar_function = func.functions.functions.front(); - auto &bind_lambda_function = scalar_function.bind_lambda; + auto bind_lambda_function = scalar_function.GetBindLambdaCallback(); if (!bind_lambda_function) { return BindResult("This scalar function does not support lambdas!"); } @@ -302,13 +309,14 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc idx_t offset = 0; if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { - auto &binding = (*lambda_bindings)[i - 1]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); - for (idx_t column_idx = binding.names.size(); column_idx > 0; column_idx--) { - auto bound_lambda_param = make_uniq(binding.names[column_idx - 1], - binding.types[column_idx - 1], offset); + for (idx_t column_idx = column_names.size(); column_idx > 0; column_idx--) { + auto bound_lambda_param = make_uniq(column_names[column_idx - 1], + column_types[column_idx - 1], offset); offset++; bound_function_expr.children.push_back(std::move(bound_lambda_param)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp index 592daa245..0d6334fc4 100644 --- a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp @@ -12,31 +12,30 @@ namespace duckdb { -idx_t GetLambdaParamCount(const vector &lambda_bindings) { +idx_t GetLambdaParamCount(vector &lambda_bindings) { idx_t count = 0; for (auto &binding : lambda_bindings) { - count += binding.names.size(); + count += binding.GetColumnCount(); } return count; } -idx_t GetLambdaParamIndex(const vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, +idx_t 
GetLambdaParamIndex(vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, const BoundLambdaRefExpression &bound_lambda_ref_expr) { D_ASSERT(bound_lambda_ref_expr.lambda_idx < lambda_bindings.size()); idx_t offset = 0; // count the remaining lambda parameters BEFORE the current lambda parameter, // as these will be in front of the current lambda parameter in the input chunk for (idx_t i = bound_lambda_ref_expr.lambda_idx + 1; i < lambda_bindings.size(); i++) { - offset += lambda_bindings[i].names.size(); + offset += lambda_bindings[i].GetColumnCount(); } - offset += - lambda_bindings[bound_lambda_ref_expr.lambda_idx].names.size() - bound_lambda_ref_expr.binding.column_index - 1; + offset += lambda_bindings[bound_lambda_ref_expr.lambda_idx].GetColumnCount() - + bound_lambda_ref_expr.binding.column_index - 1; offset += bound_lambda_expr.parameter_count; return offset; } void ExtractParameter(const ParsedExpression &expr, vector &column_names, vector &column_aliases) { - auto &column_ref = expr.Cast(); if (column_ref.IsQualified()) { throw BinderException(LambdaExpression::InvalidParametersErrorMessage()); @@ -47,7 +46,6 @@ void ExtractParameter(const ParsedExpression &expr, vector &column_names } void ExtractParameters(LambdaExpression &expr, vector &column_names, vector &column_aliases) { - // extract the lambda parameters, which are a single column // reference, or a list of column references (ROW function) string error_message; @@ -136,28 +134,26 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori BoundLambdaExpression &bound_lambda_expr, const optional_ptr bind_lambda_function, const vector &function_child_types) { - // check if the original expression is a lambda parameter if (original->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { - auto &bound_lambda_ref = original->Cast(); auto alias = bound_lambda_ref.GetAlias(); // refers to a lambda parameter outside the current lambda function // so the lambda parameter will be inside the lambda_bindings if (lambda_bindings && bound_lambda_ref.lambda_idx != lambda_bindings->size()) { - auto &binding = (*lambda_bindings)[bound_lambda_ref.lambda_idx]; - D_ASSERT(binding.names.size() == binding.types.size()); + auto &column_names = binding.GetColumnNames(); + auto &column_types = binding.GetColumnTypes(); + D_ASSERT(column_names.size() == column_types.size()); // find the matching dummy column in the lambda binding - for (idx_t column_idx = 0; column_idx < binding.names.size(); column_idx++) { + for (idx_t column_idx = 0; column_idx < binding.GetColumnCount(); column_idx++) { if (column_idx == bound_lambda_ref.binding.column_index) { - // now create the replacement auto index = GetLambdaParamIndex(*lambda_bindings, bound_lambda_expr, bound_lambda_ref); - replacement = make_uniq(binding.names[column_idx], - binding.types[column_idx], index); + replacement = + make_uniq(column_names[column_idx], column_types[column_idx], index); return; } } @@ -188,7 +184,6 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_expr, unique_ptr &expr, const optional_ptr bind_lambda_function, const vector &function_child_types) { - if (expr->GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY) { throw BinderException("subqueries in lambda expressions are not supported"); } @@ -206,7 +201,6 @@ void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_ if (expr->GetExpressionClass() == 
ExpressionClass::BOUND_COLUMN_REF || expr->GetExpressionClass() == ExpressionClass::BOUND_PARAMETER || expr->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { - if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { // Search for UNNEST. auto &column_binding = expr->Cast().binding; diff --git a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp index cce06d712..fccf527ff 100644 --- a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp @@ -11,7 +11,6 @@ namespace duckdb { void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &function, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { ReplaceMacroParameters(child, lambda_params); @@ -47,7 +46,6 @@ void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &functi void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr, vector> &lambda_params) { - switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { // If the expression is a column reference, we replace it with its argument. @@ -98,6 +96,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala // validate the arguments and separate positional and default arguments vector> positional_arguments; InsertionOrderPreservingMap> named_arguments; + binder.lambda_bindings = lambda_bindings; auto bind_result = MacroFunction::BindMacroFunction(binder, macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); if (!bind_result.error.empty()) { diff --git a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp index a0967e9ef..74b3477ee 100644 --- a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp @@ -8,6 +8,7 @@ #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/expression_binder/try_operator_binder.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" @@ -92,16 +93,25 @@ BindResult ExpressionBinder::BindGroupingFunction(OperatorExpression &op, idx_t } BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) { - if (op.GetExpressionType() == ExpressionType::GROUPING_FUNCTION) { + auto operator_type = op.GetExpressionType(); + if (operator_type == ExpressionType::GROUPING_FUNCTION) { return BindGroupingFunction(op, depth); } // Bind the children of the operator expression. We already create bound expressions. // Only those children that trigger an error are not yet bound. ErrorData error; - for (idx_t i = 0; i < op.children.size(); i++) { - BindChild(op.children[i], depth, error); + if (operator_type == ExpressionType::OPERATOR_TRY) { + D_ASSERT(op.children.size() == 1); + //! This binder is used to throw when the child expression is of a type that is not allowed. 
+ TryOperatorBinder try_operator_binder(binder, context); + try_operator_binder.BindChild(op.children[0], depth, error); + } else { + for (idx_t i = 0; i < op.children.size(); i++) { + BindChild(op.children[i], depth, error); + } } + if (error.HasError()) { return BindResult(std::move(error)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp index 109c0ecbd..3fe02467e 100644 --- a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -8,19 +8,19 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { - if (!binder.parameters) { + auto parameters = binder.GetParameters(); + if (!parameters) { throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); } auto parameter_id = expr.identifier; - D_ASSERT(binder.parameters); // Check if a parameter value has already been supplied - auto &parameter_data = binder.parameters->GetParameterData(); + auto &parameter_data = parameters->GetParameterData(); auto param_data_it = parameter_data.find(parameter_id); if (param_data_it != parameter_data.end()) { // it has! emit a constant directly auto &data = param_data_it->second; - auto return_type = binder.parameters->GetReturnType(parameter_id); + auto return_type = parameters->GetReturnType(parameter_id); bool is_literal = return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL; auto constant = make_uniq(data.GetValue()); @@ -32,7 +32,7 @@ BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t dep return BindResult(std::move(cast)); } - auto bound_parameter = binder.parameters->BindParameterExpression(expr); + auto bound_parameter = parameters->BindParameterExpression(expr); return BindResult(std::move(bound_parameter)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp index f48fc14e6..45f8ae07a 100644 --- a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp @@ -152,10 +152,15 @@ string Binder::ReplaceColumnsAlias(const string &alias, const string &column_nam void TryTransformStarLike(unique_ptr &root) { // detect "* LIKE [literal]" and similar expressions - if (root->GetExpressionClass() != ExpressionClass::FUNCTION) { + bool inverse = root->GetExpressionType() == ExpressionType::OPERATOR_NOT; + auto &expr = inverse ?
root->Cast().children[0] : root; + if (!expr) { return; } - auto &function = root->Cast(); + if (expr->GetExpressionClass() != ExpressionClass::FUNCTION) { + return; + } + auto &function = expr->Cast(); if (function.children.size() < 2 || function.children.size() > 3) { return; } @@ -197,7 +202,7 @@ void TryTransformStarLike(unique_ptr &root) { auto original_alias = root->GetAlias(); auto star_expr = std::move(left); unique_ptr child_expr; - if (function.function_name == "regexp_full_match" && star.exclude_list.empty()) { + if (!inverse && function.function_name == "regexp_full_match" && star.exclude_list.empty()) { // * SIMILAR TO '[regex]' is equivalent to COLUMNS('[regex]') so we can just move the expression directly child_expr = std::move(right); } else { @@ -207,16 +212,23 @@ void TryTransformStarLike(unique_ptr &root) { vector named_parameters; named_parameters.push_back("__lambda_col"); function.children[0] = make_uniq("__lambda_col"); + function.children[1] = std::move(right); + + unique_ptr lambda_body = std::move(expr); + if (inverse) { + vector> root_children; + root_children.push_back(std::move(lambda_body)); + lambda_body = make_uniq(ExpressionType::OPERATOR_NOT, std::move(root_children)); + } + auto lambda = make_uniq(std::move(named_parameters), std::move(lambda_body)); - auto lambda = make_uniq(std::move(named_parameters), std::move(root)); vector> filter_children; filter_children.push_back(std::move(star_expr)); filter_children.push_back(std::move(lambda)); - auto list_filter = make_uniq("list_filter", std::move(filter_children)); - child_expr = std::move(list_filter); + child_expr = make_uniq("list_filter", std::move(filter_children)); } - auto columns_expr = make_uniq(); + auto columns_expr = make_uniq(star.relation_name); columns_expr->columns = true; columns_expr->expr = std::move(child_expr); columns_expr->SetAlias(std::move(original_alias)); diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp index d413c88ed..7f03f0e32 100644 --- a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -13,20 +13,16 @@ class BoundSubqueryNode : public QueryNode { static constexpr const QueryNodeType TYPE = QueryNodeType::BOUND_SUBQUERY_NODE; public: - BoundSubqueryNode(shared_ptr subquery_binder, unique_ptr bound_node, + BoundSubqueryNode(shared_ptr subquery_binder, BoundStatement bound_node, unique_ptr subquery) : QueryNode(QueryNodeType::BOUND_SUBQUERY_NODE), subquery_binder(std::move(subquery_binder)), bound_node(std::move(bound_node)), subquery(std::move(subquery)) { } shared_ptr subquery_binder; - unique_ptr bound_node; + BoundStatement bound_node; unique_ptr subquery; - const vector> &GetSelectList() const override { - throw InternalException("Cannot get select list of bound subquery node"); - } - string ToString() const override { throw InternalException("Cannot ToString bound subquery node"); } @@ -116,15 +112,15 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept idx_t expected_columns = 1; if (expr.child) { auto &child = BoundExpression::GetExpression(*expr.child); - ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node->types); + ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node.types); if (child_expressions.empty()) { child_expressions.push_back(std::move(child)); } expected_columns = 
child_expressions.size(); } - if (bound_subquery.bound_node->types.size() != expected_columns) { + if (bound_subquery.bound_node.types.size() != expected_columns) { throw BinderException(expr, "Subquery returns %zu columns - expected %d", - bound_subquery.bound_node->types.size(), expected_columns); + bound_subquery.bound_node.types.size(), expected_columns); } } // both binding the child and binding the subquery was successful @@ -132,7 +128,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept auto subquery_binder = std::move(bound_subquery.subquery_binder); auto bound_node = std::move(bound_subquery.bound_node); LogicalType return_type = - expr.subquery_type == SubqueryType::SCALAR ? bound_node->types[0] : LogicalType(LogicalTypeId::BOOLEAN); + expr.subquery_type == SubqueryType::SCALAR ? bound_node.types[0] : LogicalType(LogicalTypeId::BOOLEAN); if (return_type.id() == LogicalTypeId::UNKNOWN) { return_type = LogicalType::SQLNULL; } @@ -144,7 +140,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept for (idx_t child_idx = 0; child_idx < child_expressions.size(); child_idx++) { auto &child = child_expressions[child_idx]; auto child_type = ExpressionBinder::GetExpressionReturnType(*child); - auto &subquery_type = bound_node->types[child_idx]; + auto &subquery_type = bound_node.types[child_idx]; LogicalType compare_type; if (!LogicalType::TryGetMaxLogicalType(context, child_type, subquery_type, compare_type)) { throw BinderException( diff --git a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp index fedcb8257..1fed1add2 100644 --- a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp @@ -15,29 +15,41 @@ #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { -unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, string key) { +static unique_ptr CreateBoundStructExtract(ClientContext &context, unique_ptr expr, + const vector &key_path, bool keep_parent_names) { vector> arguments; arguments.push_back(std::move(expr)); - arguments.push_back(make_uniq(Value(key))); + arguments.push_back(make_uniq(Value(key_path.back()))); auto extract_function = GetKeyExtractFunction(); - auto bind_info = extract_function.bind(context, extract_function, arguments); - auto return_type = extract_function.return_type; + auto bind_info = extract_function.GetBindCallback()(context, extract_function, arguments); + auto return_type = extract_function.GetReturnType(); auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), std::move(bind_info)); - result->SetAlias(std::move(key)); + + if (keep_parent_names) { + auto alias = StringUtil::Join(key_path, "."); + if (!alias.empty() && alias[0] == '.') { + alias = alias.substr(1); + } + result->SetAlias(alias); + } else { + result->SetAlias(key_path[0]); + } return std::move(result); } -unique_ptr CreateBoundStructExtractIndex(ClientContext &context, unique_ptr expr, idx_t key) { +static unique_ptr CreateBoundStructExtractIndex(ClientContext &context, unique_ptr expr, + idx_t key) { vector> arguments; arguments.push_back(std::move(expr)); arguments.push_back(make_uniq(Value::BIGINT(int64_t(key)))); auto 
extract_function = GetIndexExtractFunction(); - auto bind_info = extract_function.bind(context, extract_function, arguments); - auto return_type = extract_function.return_type; + auto bind_info = extract_function.GetBindCallback()(context, extract_function, arguments); + auto return_type = extract_function.GetReturnType(); auto result = make_uniq(return_type, std::move(extract_function), std::move(arguments), std::move(bind_info)); result->SetAlias("element" + to_string(key)); @@ -65,7 +77,7 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b ErrorData error; if (function.children.empty()) { - return BindResult(BinderException(function, "UNNEST() requires a single argument")); + return BindResult(BinderException(function, "UNNEST() requires at least one argument")); } if (inside_window) { return BindResult(BinderException(function, UnsupportedUnnestMessage())); @@ -77,13 +89,10 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } idx_t max_depth = 1; + bool keep_parent_names = false; if (function.children.size() != 1) { - bool has_parameter = false; bool supported_argument = false; for (idx_t i = 1; i < function.children.size(); i++) { - if (has_parameter) { - return BindResult(BinderException(function, "UNNEST() only supports a single additional argument")); - } if (function.children[i]->HasParameter()) { throw ParameterNotAllowedException("Parameter not allowed in unnest parameter"); } @@ -107,17 +116,19 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b if (max_depth == 0) { throw BinderException("UNNEST cannot have a max depth of 0"); } + } else if (alias == "keep_parent_names") { + keep_parent_names = value.GetValue(); } else if (!alias.empty()) { throw BinderException("Unsupported parameter \"%s\" for unnest", alias); } else { break; } - has_parameter = true; supported_argument = true; } if (!supported_argument) { - return BindResult(BinderException(function, "UNNEST - unsupported extra argument, unnest only supports " - "recursive := [true/false] or max_depth := #")); + return BindResult(BinderException( + function, "UNNEST - unsupported extra argument, unnest only supports " + "recursive := [true/false], max_depth := # or keep_parent_names := [true/false]")); } } unnest_level++; @@ -216,7 +227,6 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b if (struct_unnests > 0) { vector> struct_expressions; struct_expressions.push_back(std::move(unnest_expr)); - for (idx_t i = 0; i < struct_unnests; i++) { vector> new_expressions; // check if there are any structs left @@ -232,7 +242,14 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } } else { for (auto &entry : child_types) { - new_expressions.push_back(CreateBoundStructExtract(context, expr->Copy(), entry.first)); + vector current_key_path; + // During recursive expansion, not all expressions are BoundFunctionExpression + if (keep_parent_names && expr->type == ExpressionType::BOUND_FUNCTION) { + current_key_path.push_back(expr->alias); + } + current_key_path.push_back(entry.first); + new_expressions.push_back( + CreateBoundStructExtract(context, expr->Copy(), current_key_path, keep_parent_names)); } } has_structs = true; diff --git a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp index 4b950a29c..00b7bcded 100644 ---
a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp @@ -17,7 +17,6 @@ namespace duckdb { static LogicalType ResolveWindowExpressionType(ExpressionType window_type, const vector &child_types) { - idx_t param_count; switch (window_type) { case ExpressionType::WINDOW_RANK: @@ -115,7 +114,6 @@ static bool IsFillType(const LogicalType &type) { static LogicalType BindRangeExpression(ClientContext &context, const string &name, unique_ptr &expr, unique_ptr &order_expr) { - vector> children; D_ASSERT(order_expr.get()); diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 2a7cf8346..663abc595 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -1,93 +1,160 @@ -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression_map.hpp" -#include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/parser/query_node/list.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/main/query_result.hpp" namespace duckdb { -unique_ptr Binder::BindNode(CTENode &statement) { - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +struct BoundCTEData { + string ctename; + CTEMaterialize materialized; + idx_t setop_index; + shared_ptr child_binder; + shared_ptr cte_bind_state; +}; + +BoundStatement Binder::BindNode(QueryNode &node) { + reference current_binder(*this); + vector bound_ctes; + for (auto &cte : node.cte_map.map) { + bound_ctes.push_back(current_binder.get().PrepareCTE(cte.first, *cte.second)); + current_binder = *bound_ctes.back().child_binder; + } + BoundStatement result; + // now we bind the node + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::STATEMENT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + default: + throw InternalException("Unsupported query node type"); + } + for (idx_t i = bound_ctes.size(); i > 0; i--) { + auto &finish_binder = i == 1 ? 
*this : *bound_ctes[i - 2].child_binder; + result = finish_binder.FinishCTE(bound_ctes[i - 1], std::move(result)); + } + return result; +} - return BindCTE(statement); +CTEBindState::CTEBindState(Binder &parent_binder_p, QueryNode &cte_def_p, const vector &aliases_p) + : parent_binder(parent_binder_p), cte_def(cte_def_p), aliases(aliases_p), + active_binder_count(parent_binder.GetActiveBinders().size()) { } -unique_ptr Binder::BindCTE(CTENode &statement) { - auto result = make_uniq(); +CTEBindState::~CTEBindState() { +} - // first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); +bool CTEBindState::IsBound() const { + return query_binder.get() != nullptr; +} + +void CTEBindState::Bind(CTEBinding &binding) { + // we are lazily binding the CTE + // we need to bind it as if we were binding it during PrepareCTE + query_binder = Binder::CreateBinder(parent_binder.context, parent_binder); + query_binder->SetCanContainNulls(true); + + // we clear any expression binders that were added in the mean-time, to ensure we are not binding to any newly added + // correlated columns + auto &active_binders = parent_binder.GetActiveBinders(); + vector> stored_binders; + for (idx_t i = active_binder_count; i < active_binders.size(); i++) { + stored_binders.push_back(active_binders[i]); + } + active_binders.erase(active_binders.begin() + UnsafeNumericCast(active_binder_count), + active_binders.end()); - result->ctename = statement.ctename; - result->materialized = statement.materialized; - result->setop_index = GenerateTableIndex(); + // add this CTE to the query binder on the RHS with "CANNOT_BE_REFERENCED" to detect recursive references to + // ourselves + query_binder->bind_context.AddCTEBinding(binding.GetIndex(), binding.GetBindingAlias(), vector(), + vector(), CTEType::CANNOT_BE_REFERENCED); - AddCTE(result->ctename); + // bind the actual CTE + query = query_binder->Bind(cte_def); - result->query_binder = Binder::CreateBinder(context, this); - result->query = result->query_binder->BindNode(*statement.query); + // after binding - we add the active binders we removed back so we can leave the binder in its original state + for (auto &stored_binder : stored_binders) { + active_binders.push_back(stored_binder); + } // the result types of the CTE are the types of the LHS - result->types = result->query->types; + types = query.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->query->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + names = query.names; + for (idx_t i = 0; i < aliases.size() && i < names.size(); i++) { + names[i] = aliases[i]; } // Rename columns if duplicate names are detected - idx_t index = 1; - vector names; - // Use a case-insensitive set to track names - case_insensitive_set_t ci_names; - for (auto &n : result->names) { - string name = n; - while (ci_names.find(name) != ci_names.end()) { - name = n + "_" + std::to_string(index++); - } - names.push_back(name); - ci_names.insert(name); - } + QueryResult::DeduplicateColumns(names); +} + +BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement) { + BoundCTEData result; + + // first recursively visit the materialized CTE operations + // the left side is visited first and is added to the BindContext of the right side + D_ASSERT(statement.query); + + 
result.ctename = ctename; + result.materialized = statement.materialized; + result.setop_index = GenerateTableIndex(); - // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result->setop_index, statement.ctename, names, result->types); + // instead of eagerly binding the CTE here we add the CTE bind state to the list of CTE bindings + // the CTE is bound lazily - when referenced for the first time we perform the binding + result.cte_bind_state = make_shared_ptr(*this, *statement.query->node, statement.aliases); - result->child_binder = Binder::CreateBinder(context, this); + result.child_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. - result->child_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->child_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, names, result->types); - - if (statement.child) { - // Move all modifiers to the child node. - for (auto &modifier : statement.modifiers) { - statement.child->modifiers.push_back(std::move(modifier)); - } + auto cte_binding = make_uniq(BindingAlias(ctename), result.cte_bind_state, result.setop_index); + result.child_binder->bind_context.AddCTEBinding(std::move(cte_binding)); + return result; +} - statement.modifiers.clear(); +BoundStatement Binder::FinishCTE(BoundCTEData &bound_cte, BoundStatement child) { + if (!bound_cte.cte_bind_state->IsBound()) { + // CTE was not bound - just ignore it + MoveCorrelatedExpressions(*bound_cte.child_binder); + return child; + } + auto &bind_state = *bound_cte.cte_bind_state; + for (auto &c : bind_state.query_binder->correlated_columns) { + bound_cte.child_binder->AddCorrelatedColumn(c); + } - result->child = result->child_binder->BindNode(*statement.child); - for (auto &c : result->query_binder->correlated_columns) { - result->child_binder->AddCorrelatedColumn(c); - } + BoundStatement result; + // the result types of the CTE are the types of the LHS + result.types = child.types; + result.names = child.names; - // the result types of the CTE are the types of the LHS - result->types = result->child->types; - result->names = result->child->names; + MoveCorrelatedExpressions(*bound_cte.child_binder); + MoveCorrelatedExpressions(*bind_state.query_binder); - MoveCorrelatedExpressions(*result->child_binder); - } + auto cte_query = std::move(bind_state.query.plan); + auto cte_child = std::move(child.plan); - MoveCorrelatedExpressions(*result->query_binder); + auto root = make_uniq(bound_cte.ctename, bound_cte.setop_index, result.types.size(), + std::move(cte_query), std::move(cte_child), bound_cte.materialized); + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || + bound_cte.child_binder->has_unplanned_dependent_joins || + bind_state.query_binder->has_unplanned_dependent_joins; + result.plan = std::move(root); return result; } diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 54e9e9fa5..efc9740de 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -3,14 +3,12 @@ #include 
"duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" namespace duckdb { -unique_ptr Binder::BindNode(RecursiveCTENode &statement) { - auto result = make_uniq(); - +BoundStatement Binder::BindNode(RecursiveCTENode &statement) { // first recursively visit the recursive CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.left); @@ -19,53 +17,55 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw BinderException("UNION ALL cannot be used with USING KEY in recursive CTE."); } - result->ctename = statement.ctename; - result->union_all = statement.union_all; - result->setop_index = GenerateTableIndex(); + auto ctename = statement.ctename; + auto union_all = statement.union_all; + auto setop_index = GenerateTableIndex(); - result->left_binder = Binder::CreateBinder(context, this); - result->left = result->left_binder->BindNode(*statement.left); + auto left_binder = Binder::CreateBinder(context, this); + auto left = left_binder->BindNode(*statement.left); + BoundStatement result; // the result types of the CTE are the types of the LHS - result->types = result->left->types; + result.types = left.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->left->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + result.names = left.names; + for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { + result.names[i] = statement.aliases[i]; } // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); + bind_context.AddGenericBinding(setop_index, statement.ctename, result.names, result.types); - result->right_binder = Binder::CreateBinder(context, this); + auto right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first - // as we are binding a CTE currently, we take precendence over the existing binding. - // This implements the CTE shadowing behavior. 
- result->right_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->right_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, - result->types, !statement.key_targets.empty()); - - result->right = result->right_binder->BindNode(*statement.right); - for (auto &c : result->left_binder->correlated_columns) { - result->right_binder->AddCorrelatedColumn(c); + BindingAlias cte_alias(statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(cte_alias), result.names, result.types); + if (!statement.key_targets.empty()) { + BindingAlias recurring_alias("recurring", statement.ctename); + right_binder->bind_context.AddCTEBinding(setop_index, std::move(recurring_alias), result.names, result.types); + } + + auto right = right_binder->BindNode(*statement.right); + for (auto &c : left_binder->correlated_columns) { + right_binder->AddCorrelatedColumn(c); } // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result->left_binder); - MoveCorrelatedExpressions(*result->right_binder); + MoveCorrelatedExpressions(*left_binder); + MoveCorrelatedExpressions(*right_binder); + vector> key_targets; // bind specified keys to the referenced column auto expression_binder = ExpressionBinder(*this, context); - for (unique_ptr &expr : statement.key_targets) { + for (auto &expr : statement.key_targets) { auto bound_expr = expression_binder.Bind(expr); D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - result->key_targets.push_back(std::move(bound_expr)); + key_targets.push_back(std::move(bound_expr)); } // now both sides have been bound we can resolve types - if (result->left->types.size() != result->right->types.size()) { + if (left.types.size() != right.types.size()) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -74,7 +74,42 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); } - return std::move(result); + // Generate the logical plan for the left and right sides of the set operation + left_binder->is_outside_flattened = is_outside_flattened; + right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = std::move(left.plan); + auto right_node = std::move(right.plan); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || left_binder->has_unplanned_dependent_joins || + right_binder->has_unplanned_dependent_joins; + + // for both the left and right sides, cast them to the same types + left_node = CastLogicalOperatorToTypes(left.types, result.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(right.types, result.types, std::move(right_node)); + + auto recurring_binding = right_binder->GetCTEBinding(BindingAlias("recurring", ctename)); + bool ref_recurring = recurring_binding && recurring_binding->IsReferenced(); + if (key_targets.empty() && ref_recurring) { + throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); + } + + // Check if there is a reference to the recursive or recurring table, if not create a set operator. 
+ auto cte_binding = right_binder->GetCTEBinding(BindingAlias(ctename)); + bool ref_cte = cte_binding && cte_binding->IsReferenced(); + if (!ref_cte && !ref_recurring) { + auto root = + make_uniq(setop_index, result.types.size(), std::move(left_node), + std::move(right_node), LogicalOperatorType::LOGICAL_UNION, union_all); + result.plan = std::move(root); + } else { + auto root = make_uniq(ctename, setop_index, result.types.size(), union_all, + std::move(key_targets), std::move(left_node), std::move(right_node)); + root->ref_recurring = ref_recurring; + result.plan = std::move(root); + } + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 4f52dfc4a..26859cb5b 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/function/function_binder.hpp" @@ -141,12 +142,27 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B } } order_binder.SetQueryComponent("DISTINCT ON"); + auto &order_binders = order_binder.GetBinders(); for (auto &distinct_on_target : distinct.distinct_on_targets) { - auto expr = BindOrderExpression(order_binder, std::move(distinct_on_target)); - if (!expr) { - continue; + vector> target_list; + order_binders[0].get().ExpandStarExpression(std::move(distinct_on_target), target_list); + for (auto &target : target_list) { + auto expr = BindOrderExpression(order_binder, std::move(target)); + if (!expr) { + continue; + } + // Skip duplicates + bool duplicate = false; + for (auto &existing : bound_distinct->target_distincts) { + if (expr->Equals(*existing)) { + duplicate = true; + break; + } + } + if (!duplicate) { + bound_distinct->target_distincts.push_back(std::move(expr)); + } } - bound_distinct->target_distincts.push_back(std::move(expr)); } order_binder.SetQueryComponent(); @@ -154,7 +170,6 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B break; } case ResultModifierType::ORDER_MODIFIER: { - auto &order = mod->Cast(); auto bound_order = make_uniq(); auto &config = DBConfig::GetConfig(context); @@ -363,7 +378,7 @@ void Binder::BindModifiers(BoundQueryNode &result, idx_t table_index, const vect } } -unique_ptr Binder::BindNode(SelectNode &statement) { +BoundStatement Binder::BindNode(SelectNode &statement) { D_ASSERT(statement.from_table); // first bind the FROM table statement @@ -403,21 +418,22 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { - D_ASSERT(from_table); +BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { + D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); - auto result = make_uniq(); - result->projection_index = GenerateTableIndex(); - result->group_index = GenerateTableIndex(); - result->aggregate_index = GenerateTableIndex(); - result->groupings_index = GenerateTableIndex(); - result->window_index = GenerateTableIndex(); - result->prune_index = GenerateTableIndex(); - - result->from_table = std::move(from_table); + auto result_ptr = make_uniq(); + auto 
&result = *result_ptr; + result.projection_index = GenerateTableIndex(); + result.group_index = GenerateTableIndex(); + result.aggregate_index = GenerateTableIndex(); + result.groupings_index = GenerateTableIndex(); + result.window_index = GenerateTableIndex(); + result.prune_index = GenerateTableIndex(); + + result.from_table = std::move(from_table); // bind the sample clause if (statement.sample) { - result->sample_options = std::move(statement.sample); + result.sample_options = std::move(statement.sample); } // visit the select list and expand any "*" statements @@ -429,19 +445,19 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } statement.select_list = std::move(new_select_list); - auto &bind_state = result->bind_state; + auto &bind_state = result.bind_state; for (idx_t i = 0; i < statement.select_list.size(); i++) { auto &expr = statement.select_list[i]; - result->names.push_back(expr->GetName()); + result.names.push_back(expr->GetName()); ExpressionBinder::QualifyColumnNames(*this, expr); if (!expr->GetAlias().empty()) { bind_state.alias_map[expr->GetAlias()] = i; - result->names[i] = expr->GetAlias(); + result.names[i] = expr->GetAlias(); } bind_state.projection_map[*expr] = i; bind_state.original_expressions.push_back(expr->Copy()); } - result->column_count = statement.select_list.size(); + result.column_count = statement.select_list.size(); // first visit the WHERE clause // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses @@ -452,12 +468,12 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ ColumnAliasBinder alias_binder(bind_state); WhereBinder where_binder(*this, context, &alias_binder); unique_ptr condition = std::move(statement.where_clause); - result->where_clause = where_binder.Bind(condition); + result.where_clause = where_binder.Bind(condition); } // now bind all the result modifiers; including DISTINCT and ORDER BY targets OrderBinder order_binder({*this}, statement, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); vector> unbound_groups; BoundGroupInformation info; @@ -465,9 +481,11 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!group_expressions.empty()) { // the statement has a GROUP BY clause, bind it unbound_groups.resize(group_expressions.size()); - GroupBinder group_binder(*this, context, statement, result->group_index, bind_state, info.alias_map); + GroupBinder group_binder(*this, context, statement, result.group_index, bind_state, info.alias_map); + // Allow NULL constants in GROUP BY to maintain their SQLNULL type + auto prev_can_contain_nulls = this->can_contain_nulls; + this->can_contain_nulls = true; for (idx_t i = 0; i < group_expressions.size(); i++) { - // we keep a copy of the unbound expression; // we keep the unbound copy around to check for group references in the SELECT and HAVING clause // the reason we want the unbound copy is because we want to figure out whether an expression @@ -489,7 +507,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!contains_subquery && requires_collation) { // if there is a collation on a group x, we should group by the collated expr, // but also push a first(x) aggregate in case x is selected (uncollated) - info.collated_groups[i] = result->aggregates.size(); + info.collated_groups[i] = result.aggregates.size(); auto first_fun = FirstFunctionGetter::GetFunction(bound_expr_ref.return_type); vector> first_children; @@ -499,9 +517,9 @@ 
unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ FunctionBinder function_binder(*this); auto function = function_binder.BindAggregateFunction(first_fun, std::move(first_children)); function->SetAlias("__collated_group"); - result->aggregates.push_back(std::move(function)); + result.aggregates.push_back(std::move(function)); } - result->groups.group_expressions.push_back(std::move(bound_expr)); + result.groups.group_expressions.push_back(std::move(bound_expr)); // in the unbound expression we DO bind the table names of any ColumnRefs // we do this to make sure that "table.a" and "a" are treated the same @@ -511,14 +529,15 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ ExpressionBinder::QualifyColumnNames(*this, unbound_groups[i]); info.map[*unbound_groups[i]] = i; } + this->can_contain_nulls = prev_can_contain_nulls; } - result->groups.grouping_sets = std::move(statement.groups.grouping_sets); + result.groups.grouping_sets = std::move(statement.groups.grouping_sets); // bind the HAVING clause, if any if (statement.having) { - HavingBinder having_binder(*this, context, *result, info, statement.aggregate_handling); + HavingBinder having_binder(*this, context, result, info, statement.aggregate_handling); ExpressionBinder::QualifyColumnNames(having_binder, statement.having); - result->having = having_binder.Bind(statement.having); + result.having = having_binder.Bind(statement.having); } // bind the QUALIFY clause, if any @@ -527,9 +546,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); } - QualifyBinder qualify_binder(*this, context, *result, info); + QualifyBinder qualify_binder(*this, context, result, info); ExpressionBinder::QualifyColumnNames(*this, statement.qualify); - result->qualify = qualify_binder.Bind(statement.qualify); + result.qualify = qualify_binder.Bind(statement.qualify); if (qualify_binder.HasBoundColumns()) { if (qualify_binder.BoundAggregates()) { throw BinderException("Cannot mix aggregates with non-aggregated columns!"); @@ -539,7 +558,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } // after that, we bind to the SELECT list - SelectBinder select_binder(*this, context, *result, info); + SelectBinder select_binder(*this, context, result, info); // if we expand select-list expressions, e.g., via UNNEST, then we need to possibly // adjust the column index of the already bound ORDER BY modifiers, and not only set their types @@ -549,13 +568,13 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (idx_t i = 0; i < statement.select_list.size(); i++) { bool is_window = statement.select_list[i]->IsWindow(); - idx_t unnest_count = result->unnests.size(); + idx_t unnest_count = result.unnests.size(); LogicalType result_type; auto expr = select_binder.Bind(statement.select_list[i], &result_type, true); - bool is_original_column = i < result->column_count; + bool is_original_column = i < result.column_count; bool can_group_by_all = statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES && is_original_column; - result->bound_column_count++; + result.bound_column_count++; if (expr->GetExpressionType() == ExpressionType::BOUND_EXPANDED) { if (!is_original_column) { @@ -571,9 +590,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (auto &struct_expr : struct_expressions) { 
new_names.push_back(struct_expr->GetName()); - result->types.push_back(struct_expr->return_type); + result.types.push_back(struct_expr->return_type); internal_sql_types.push_back(struct_expr->return_type); - result->select_list.push_back(std::move(struct_expr)); + result.select_list.push_back(std::move(struct_expr)); } bind_state.AddExpandedColumn(struct_expressions.size()); continue; @@ -594,7 +613,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (is_window) { throw BinderException("Cannot group on a window clause"); } - if (result->unnests.size() > unnest_count) { + if (result.unnests.size() > unnest_count) { throw BinderException("Cannot group on an UNNEST or UNLIST clause"); } // we are forcing aggregates, and the node has columns bound @@ -602,10 +621,10 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ group_by_all_indexes.push_back(i); } - result->select_list.push_back(std::move(expr)); + result.select_list.push_back(std::move(expr)); if (is_original_column) { - new_names.push_back(std::move(result->names[i])); - result->types.push_back(result_type); + new_names.push_back(std::move(result.names[i])); + result.types.push_back(result_type); } internal_sql_types.push_back(result_type); @@ -617,31 +636,31 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // push the GROUP BY ALL expressions into the group set for (auto &group_by_all_index : group_by_all_indexes) { - auto &expr = result->select_list[group_by_all_index]; + auto &expr = result.select_list[group_by_all_index]; auto group_ref = make_uniq( - expr->return_type, ColumnBinding(result->group_index, result->groups.group_expressions.size())); - result->groups.group_expressions.push_back(std::move(expr)); + expr->return_type, ColumnBinding(result.group_index, result.groups.group_expressions.size())); + result.groups.group_expressions.push_back(std::move(expr)); expr = std::move(group_ref); } set group_by_all_indexes_set; if (!group_by_all_indexes.empty()) { - idx_t num_set_indexes = result->groups.group_expressions.size(); + idx_t num_set_indexes = result.groups.group_expressions.size(); for (idx_t i = 0; i < num_set_indexes; i++) { group_by_all_indexes_set.insert(i); } - D_ASSERT(result->groups.grouping_sets.empty()); - result->groups.grouping_sets.push_back(group_by_all_indexes_set); + D_ASSERT(result.groups.grouping_sets.empty()); + result.groups.grouping_sets.push_back(group_by_all_indexes_set); } - result->column_count = new_names.size(); - result->names = std::move(new_names); - result->need_prune = result->select_list.size() > result->column_count; + result.column_count = new_names.size(); + result.names = std::move(new_names); + result.need_prune = result.select_list.size() > result.column_count; // in the normal select binder, we bind columns as if there is no aggregation // i.e. 
in the query [SELECT i, SUM(i) FROM integers;] the "i" will be bound as a normal column // since we have an aggregation, we need to either (1) throw an error, or (2) wrap the column in a FIRST() aggregate // we choose the former one [CONTROVERSIAL: this is the PostgreSQL behavior] - if (!result->groups.group_expressions.empty() || !result->aggregates.empty() || statement.having || - !result->groups.grouping_sets.empty()) { + if (!result.groups.group_expressions.empty() || !result.aggregates.empty() || statement.having || + !result.groups.grouping_sets.empty()) { if (statement.aggregate_handling == AggregateHandling::NO_AGGREGATES_ALLOWED) { throw BinderException("Aggregates cannot be present in a Project relation!"); } else { @@ -672,13 +691,19 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // QUALIFY clause requires at least one window function to be specified in at least one of the SELECT column list or // the filter predicate of the QUALIFY clause - if (statement.qualify && result->windows.empty()) { + if (statement.qualify && result.windows.empty()) { throw BinderException("at least one window function must appear in the SELECT column or QUALIFY clause"); } // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions - BindModifiers(*result, result->projection_index, result->names, internal_sql_types, bind_state); - return std::move(result); + BindModifiers(result, result.projection_index, result.names, internal_sql_types, bind_state); + + BoundStatement result_statement; + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.original_expressions = std::move(result.bind_state.original_expressions); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index 50c6b3c06..34ea09f51 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -10,89 +10,109 @@ #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/common/enum_util.hpp" namespace duckdb { -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state, const vector &reorder_idx) { - if (node.type == QueryNodeType::SET_OPERATION_NODE) { - // setop, recurse - auto &setop = node.Cast(); +struct SetOpAliasGatherer { +public: + explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { + } - // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { - // for UNION BY NAME - create a new re-order index - case_insensitive_map_t reorder_map; - for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { - reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; - } + void GatherAliases(BoundStatement &stmt, const vector &reorder_idx); + void GatherSetOpAliases(SetOperationType setop_type, const vector &names, + vector &bound_children, const vector &reorder_idx); - // use new reorder index - for (auto &child : setop.bound_children) { - vector new_reorder_idx; - for (idx_t col_idx = 0; col_idx < child.node->names.size(); col_idx++) { - auto &col_name = child.node->names[col_idx]; - auto 
entry = reorder_map.find(col_name); - if (entry == reorder_map.end()) { - throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); - } - new_reorder_idx.push_back(entry->second); - } - GatherAliases(*child.node, bind_state, new_reorder_idx); - } - return; - } +private: + SelectBindState &bind_state; +}; - for (auto &child : setop.bound_children) { - GatherAliases(*child.node, bind_state, reorder_idx); - } - } else { - // query node - D_ASSERT(node.type == QueryNodeType::SELECT_NODE); - auto &select = node.Cast(); - // fill the alias lists with the names - D_ASSERT(reorder_idx.size() == select.names.size()); - for (idx_t i = 0; i < select.names.size(); i++) { - auto &name = select.names[i]; - // first check if the alias is already in there - auto entry = bind_state.alias_map.find(name); +void SetOpAliasGatherer::GatherAliases(BoundStatement &stmt, const vector &reorder_idx) { + if (stmt.extra_info.setop_type != SetOperationType::NONE) { + GatherSetOpAliases(stmt.extra_info.setop_type, stmt.names, stmt.extra_info.bound_children, reorder_idx); + return; + } - idx_t index = reorder_idx[i]; + // query node + auto &select_names = stmt.names; + // fill the alias lists with the names + D_ASSERT(reorder_idx.size() == select_names.size()); + for (idx_t i = 0; i < select_names.size(); i++) { + auto &name = select_names[i]; + // first check if the alias is already in there + auto entry = bind_state.alias_map.find(name); - if (entry == bind_state.alias_map.end()) { - // the alias is not in there yet, just assign it - bind_state.alias_map[name] = index; + idx_t index = reorder_idx[i]; + + if (entry == bind_state.alias_map.end()) { + // the alias is not in there yet, just assign it + bind_state.alias_map[name] = index; + } + } + // check if the expression matches one of the expressions in the original expression list + auto &select_list = stmt.extra_info.original_expressions; + for (idx_t i = 0; i < select_list.size(); i++) { + auto &expr = select_list[i]; + idx_t index = reorder_idx[i]; + // now check if the node is already in the set of expressions + auto expr_entry = bind_state.projection_map.find(*expr); + if (expr_entry != bind_state.projection_map.end()) { + // the node is in there + // repeat the same as with the alias: if there is an ambiguity we insert "-1" + if (expr_entry->second != index) { + bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; } + } else { + // not in there yet, just place it in there + bind_state.projection_map[*expr] = index; + } + } +} + +void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, + vector &bound_children, const vector &reorder_idx) { + // create new reorder index + if (setop_type == SetOperationType::UNION_BY_NAME) { + auto &setop_names = stmt_names; + // for UNION BY NAME - create a new re-order index + case_insensitive_map_t reorder_map; + for (idx_t col_idx = 0; col_idx < setop_names.size(); ++col_idx) { + reorder_map[setop_names[col_idx]] = reorder_idx[col_idx]; } - // check if the expression matches one of the expressions in the original expression list - for (idx_t i = 0; i < select.bind_state.original_expressions.size(); i++) { - auto &expr = select.bind_state.original_expressions[i]; - idx_t index = reorder_idx[i]; - // now check if the node is already in the set of expressions - auto expr_entry = bind_state.projection_map.find(*expr); - if (expr_entry != bind_state.projection_map.end()) { - // the node is in there - // repeat the same as with the alias: if there is 
an ambiguity we insert "-1" - if (expr_entry->second != index) { - bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; + + // use new reorder index + for (auto &child : bound_children) { + vector new_reorder_idx; + auto &child_names = child.names; + for (idx_t col_idx = 0; col_idx < child_names.size(); col_idx++) { + auto &col_name = child_names[col_idx]; + auto entry = reorder_map.find(col_name); + if (entry == reorder_map.end()) { + throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); } - } else { - // not in there yet, just place it in there - bind_state.projection_map[*expr] = index; + new_reorder_idx.push_back(entry->second); } + GatherAliases(child, new_reorder_idx); + } + } else { + for (auto &child : bound_children) { + GatherAliases(child, reorder_idx); } } } -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state) { +static void GatherAliases(BoundSetOperationNode &root, vector &child_statements, + SelectBindState &bind_state) { + SetOpAliasGatherer gatherer(bind_state); vector reorder_idx; - for (idx_t i = 0; i < node.names.size(); i++) { + for (idx_t i = 0; i < root.names.size(); i++) { reorder_idx.push_back(i); } - GatherAliases(node, bind_state, reorder_idx); + gatherer.GatherSetOpAliases(root.setop_type, root.names, child_statements, reorder_idx); } -static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode &result, bool can_contain_nulls) { +void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); vector> node_name_maps; case_insensitive_set_t global_name_set; @@ -101,10 +121,10 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { - auto &child_node = *child.node; + auto &child_names = child.names; case_insensitive_map_t node_name_map; - for (idx_t i = 0; i < child_node.names.size(); ++i) { - auto &col_name = child_node.names[i]; + for (idx_t i = 0; i < child_names.size(); ++i) { + auto &col_name = child_names[i]; if (node_name_map.find(col_name) != node_name_map.end()) { throw BinderException( "UNION (ALL) BY NAME operation doesn't support duplicate names in the SELECT list - " @@ -129,7 +149,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & auto &col_name = result.names[i]; LogicalType result_type(LogicalTypeId::INVALID); for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { - auto &child = result.bound_children[child_idx]; + auto &child_types = result.bound_children[child_idx].types; auto &child_name_map = node_name_maps[child_idx]; // check if the column exists in this child node auto entry = child_name_map.find(col_name); @@ -137,7 +157,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & need_reorder = true; } else { auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; + auto &child_col_type = child_types[col_idx_in_child]; // the child exists in this node - compute the type if (result_type.id() == LogicalTypeId::INVALID) { result_type = child_col_type; @@ -165,6 +185,8 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & return; } // If reorder is required, generate the expressions for each node + vector>> reorder_expressions; + 
reorder_expressions.resize(result.bound_children.size()); for (idx_t i = 0; i < new_size; ++i) { auto &col_name = result.names[i]; for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { @@ -179,34 +201,48 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & } else { // the column exists - reference it auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; - expr = make_uniq(child_col_type, - ColumnBinding(child.node->GetRootIndex(), col_idx_in_child)); + auto &child_col_type = child.types[col_idx_in_child]; + auto root_idx = child.plan->GetRootIndex(); + expr = make_uniq(child_col_type, ColumnBinding(root_idx, col_idx_in_child)); } - child.reorder_expressions.push_back(std::move(expr)); + reorder_expressions[child_idx].push_back(std::move(expr)); } } + // now push projections for each node + for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { + auto &child = result.bound_children[child_idx]; + auto &child_reorder_expressions = reorder_expressions[child_idx]; + // if we have re-order expressions push a projection + vector child_types; + for (auto &expr : child_reorder_expressions) { + child_types.push_back(expr->return_type); + } + auto child_projection = + make_uniq(GenerateTableIndex(), std::move(child_reorder_expressions)); + child_projection->children.push_back(std::move(child.plan)); + child.plan = std::move(child_projection); + child.types = std::move(child_types); + } } -static void GatherSetOpBinders(BoundQueryNode &node, Binder &binder, vector> &binders) { - if (node.type != QueryNodeType::SET_OPERATION_NODE) { - binders.push_back(binder); - return; +static void GatherSetOpBinders(vector &children, vector> &binders, + vector> &result) { + for (auto &child_binder : binders) { + result.push_back(*child_binder); } - auto &setop_node = node.Cast(); - for (auto &child : setop_node.bound_children) { - GatherSetOpBinders(*child.node, *child.binder, binders); + for (auto &child_node : children) { + GatherSetOpBinders(child_node.extra_info.bound_children, child_node.extra_info.child_binders, result); } } -unique_ptr Binder::BindNode(SetOperationNode &statement) { - auto result = make_uniq(); - result->setop_type = statement.setop_type; - result->setop_all = statement.setop_all; +BoundStatement Binder::BindNode(SetOperationNode &statement) { + BoundSetOperationNode result; + result.setop_type = statement.setop_type; + result.setop_all = statement.setop_all; // first recursively visit the set operations // all children have an independent BindContext and Binder - result->setop_index = GenerateTableIndex(); + result.setop_index = GenerateTableIndex(); if (statement.children.size() < 2) { throw InternalException("Set Operations must have at least 2 children"); } @@ -215,27 +251,23 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { throw InternalException("Set Operation type must have exactly 2 children - except for UNION/UNION_BY_NAME"); } for (auto &child : statement.children) { - BoundSetOpChild bound_child; - bound_child.binder = Binder::CreateBinder(context, this); - bound_child.binder->can_contain_nulls = true; - bound_child.node = bound_child.binder->BindNode(*child); - result->bound_children.push_back(std::move(bound_child)); - } - - // move the correlated expressions from the child binders to this binder - for (auto &bound_child : result->bound_children) { - MoveCorrelatedExpressions(*bound_child.binder); + auto child_binder = 
Binder::CreateBinder(context, this); + child_binder->can_contain_nulls = true; + auto child_node = child_binder->BindNode(*child); + MoveCorrelatedExpressions(*child_binder); + result.bound_children.push_back(std::move(child_node)); + result.child_binders.push_back(std::move(child_binder)); } - if (result->setop_type == SetOperationType::UNION_BY_NAME) { + if (result.setop_type == SetOperationType::UNION_BY_NAME) { // UNION BY NAME - merge the columns from all sides - BuildUnionByNameInfo(context, *result, can_contain_nulls); + BuildUnionByNameInfo(result); } else { // UNION ALL BY POSITION - the columns of both sides must match exactly - result->names = result->bound_children[0].node->names; - auto result_columns = result->bound_children[0].node->types.size(); - for (idx_t i = 1; i < result->bound_children.size(); ++i) { - if (result->bound_children[i].node->types.size() != result_columns) { + result.names = result.bound_children[0].names; + auto result_columns = result.bound_children[0].types.size(); + for (idx_t i = 1; i < result.bound_children.size(); ++i) { + if (result.bound_children[i].types.size() != result_columns) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -243,40 +275,48 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { // figure out the types of the setop result by picking the max of both for (idx_t i = 0; i < result_columns; i++) { - auto result_type = result->bound_children[0].node->types[i]; - for (idx_t child_idx = 1; child_idx < result->bound_children.size(); ++child_idx) { - auto &child_node = *result->bound_children[child_idx].node; - result_type = LogicalType::ForceMaxLogicalType(result_type, child_node.types[i]); + auto result_type = result.bound_children[0].types[i]; + for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { + auto &child_types = result.bound_children[child_idx].types; + result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); } if (!can_contain_nulls) { if (ExpressionBinder::ContainsNullType(result_type)) { result_type = ExpressionBinder::ExchangeNullType(result_type); } } - result->types.push_back(result_type); + result.types.push_back(result_type); } } + if (!statement.setop_all) { + statement.modifiers.insert(statement.modifiers.begin(), make_uniq()); + statement.setop_all = false; // Already handled + } + SelectBindState bind_state; if (!statement.modifiers.empty()) { // handle the ORDER BY/DISTINCT clauses - - // we recursively visit the children of this node to extract aliases and expressions that can be referenced - // in the ORDER BYs - GatherAliases(*result, bind_state); + vector> binders; + GatherSetOpBinders(result.bound_children, result.child_binders, binders); + GatherAliases(result, result.bound_children, bind_state); // now we perform the actual resolution of the ORDER BY/DISTINCT expressions - vector> binders; - for (auto &child : result->bound_children) { - GatherSetOpBinders(*child.node, *child.binder, binders); - } OrderBinder order_binder(binders, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); } // finally bind the types of the ORDER/DISTINCT clause expressions - BindModifiers(*result, result->setop_index, result->names, result->types, bind_state); - return std::move(result); + BindModifiers(result, result.setop_index, result.names, result.types, bind_state); + + BoundStatement result_statement; + result_statement.types = 
result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.setop_type = statement.setop_type; + result_statement.extra_info.bound_children = std::move(result.bound_children); + result_statement.extra_info.child_binders = std::move(result.child_binders); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp new file mode 100644 index 000000000..6f6f9941a --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp @@ -0,0 +1,26 @@ +#include "duckdb/parser/query_node/statement_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BoundStatement Binder::BindNode(StatementNode &statement) { + // switch on type here to ensure we bind WITHOUT ctes to prevent infinite recursion + switch (statement.stmt.type) { + case StatementType::INSERT_STATEMENT: + return Bind(statement.stmt.Cast<InsertStatement>()); + case StatementType::DELETE_STATEMENT: + return Bind(statement.stmt.Cast<DeleteStatement>()); + case StatementType::UPDATE_STATEMENT: + return Bind(statement.stmt.Cast<UpdateStatement>()); + case StatementType::MERGE_INTO_STATEMENT: + return Bind(statement.stmt.Cast<MergeIntoStatement>()); + default: + return Bind(statement.stmt); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp deleted file mode 100644 index 5bd06c0e5..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/operator/logical_materialized_cte.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" - -namespace duckdb { - -unique_ptr<LogicalOperator> Binder::CreatePlan(BoundCTENode &node) { - // Generate the logical plan for the cte_query and child. - auto cte_query = CreatePlan(*node.query); - auto cte_child = CreatePlan(*node.child); - - auto root = make_uniq<LogicalMaterializedCTE>(node.ctename, node.setop_index, node.types.size(), - std::move(cte_query), std::move(cte_child), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - - return VisitQueryNode(node, std::move(root)); -} - -unique_ptr<LogicalOperator> Binder::CreatePlan(BoundCTENode &node, unique_ptr<LogicalOperator> base) { - // Generate the logical plan for the cte_query and child.
- auto cte_query = CreatePlan(*node.query); - unique_ptr<LogicalOperator> root; - if (node.child && node.child->type == QueryNodeType::CTE_NODE) { - root = CreatePlan(node.child->Cast<BoundCTENode>(), std::move(base)); - } else if (node.child) { - root = CreatePlan(*node.child); - } else { - root = std::move(base); - } - - // Only keep the materialized CTE, if it is used - if (node.child_binder->bind_context.cte_references[node.ctename] && - *node.child_binder->bind_context.cte_references[node.ctename] > 0) { - - // Push the CTE through single-child operators so query modifiers appear ABOVE the CTE (internal issue #2652) - // Otherwise, we may have a LIMIT on top of the CTE, and an ORDER BY in the query, and we can't make a TopN - reference<unique_ptr<LogicalOperator>> cte_child = root; - while (cte_child.get()->children.size() == 1 && cte_child.get()->type != LogicalOperatorType::LOGICAL_CTE_REF) { - cte_child = cte_child.get()->children[0]; - } - cte_child.get() = - make_uniq<LogicalMaterializedCTE>(node.ctename, node.setop_index, node.types.size(), std::move(cte_query), - std::move(cte_child.get()), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || - node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - } - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp deleted file mode 100644 index 4064136b6..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_recursive_cte.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -unique_ptr<LogicalOperator> Binder::CreatePlan(BoundRecursiveCTENode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = node.left_binder->CreatePlan(*node.left); - auto right_node = node.right_binder->CreatePlan(*node.right); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.left_binder->has_unplanned_dependent_joins || - node.right_binder->has_unplanned_dependent_joins; - - // for both the left and right sides, cast them to the same types - left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); - - bool ref_recurring = node.right_binder->bind_context.cte_references["recurring." + node.ctename] && - *node.right_binder->bind_context.cte_references["recurring." + node.ctename] != 0; - - if (node.key_targets.empty() && ref_recurring) { - throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); - } - - // Check if there is a reference to the recursive or recurring table, if not create a set operator.
- if ((!node.right_binder->bind_context.cte_references[node.ctename] || - *node.right_binder->bind_context.cte_references[node.ctename] == 0) && - !ref_recurring) { - auto root = - make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), LogicalOperatorType::LOGICAL_UNION, node.union_all); - return VisitQueryNode(node, std::move(root)); - } - - auto root = - make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, - std::move(node.key_targets), std::move(left_node), std::move(right_node)); - root->ref_recurring = ref_recurring; - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp index 46e5d2e12..335ccdb1f 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp @@ -16,10 +16,8 @@ unique_ptr Binder::PlanFilter(unique_ptr condition, } unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { - unique_ptr root; - D_ASSERT(statement.from_table); - root = CreatePlan(*statement.from_table); - D_ASSERT(root); + D_ASSERT(statement.from_table.plan); + auto root = std::move(statement.from_table.plan); // plan the sample clause if (statement.sample_options) { @@ -30,7 +28,7 @@ unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { root = PlanFilter(std::move(statement.where_clause), std::move(root)); } - if (!statement.aggregates.empty() || !statement.groups.group_expressions.empty()) { + if (!statement.aggregates.empty() || !statement.groups.group_expressions.empty() || statement.having) { if (!statement.groups.group_expressions.empty()) { // visit the groups for (auto &group : statement.groups.group_expressions) { diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index 9b0fa7c94..a1a7f60b0 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -10,8 +10,8 @@ namespace duckdb { // Optionally push a PROJECTION operator -unique_ptr Binder::CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, +unique_ptr Binder::CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op) { D_ASSERT(op); // first check if we even need to cast @@ -113,29 +113,16 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { D_ASSERT(node.bound_children.size() >= 2); vector> children; - for (auto &child : node.bound_children) { - child.binder->is_outside_flattened = is_outside_flattened; + for (idx_t child_idx = 0; child_idx < node.bound_children.size(); child_idx++) { + auto &child = node.bound_children[child_idx]; + auto &child_binder = *node.child_binders[child_idx]; // construct the logical plan for the child node - auto child_node = child.binder->CreatePlan(*child.node); - if (!child.reorder_expressions.empty()) { - // if we have re-order expressions push a projection - vector child_types; - for (auto &expr : child.reorder_expressions) { - child_types.push_back(expr->return_type); - } - auto child_projection = - make_uniq(GenerateTableIndex(), std::move(child.reorder_expressions)); - child_projection->children.push_back(std::move(child_node)); - child_node = std::move(child_projection); - - child_node = CastLogicalOperatorToTypes(child_types, node.types, std::move(child_node)); - } else { - // 
otherwise push only casts - child_node = CastLogicalOperatorToTypes(child.node->types, node.types, std::move(child_node)); - } + auto child_node = std::move(child.plan); + // push casts for the target types + child_node = CastLogicalOperatorToTypes(child.types, node.types, std::move(child_node)); // check if there are any unplanned subqueries left in any child - if (child.binder->has_unplanned_dependent_joins) { + if (child_binder.has_unplanned_dependent_joins) { has_unplanned_dependent_joins = true; } children.push_back(std::move(child_node)); diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp index 2664903d3..29a419ab7 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -186,9 +186,10 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq } } -static unique_ptr -CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, - unique_ptr original_plan, bool perform_delim) { +static unique_ptr CreateDuplicateEliminatedJoin(const CorrelatedColumns &correlated_columns, + JoinType join_type, + unique_ptr original_plan, + bool perform_delim) { auto delim_join = make_uniq(join_type); delim_join->correlated_columns = correlated_columns; delim_join->perform_delim = perform_delim; @@ -216,7 +217,7 @@ static bool PerformDelimOnType(const LogicalType &type) { return true; } -static bool PerformDuplicateElimination(Binder &binder, vector &correlated_columns) { +static bool PerformDuplicateElimination(Binder &binder, CorrelatedColumns &correlated_columns) { if (!ClientConfig::GetConfig(binder.context).enable_optimizer) { // if optimizations are disabled we always do a delim join return true; @@ -235,7 +236,8 @@ static bool PerformDuplicateElimination(Binder &binder, vector Binder::PlanSubquery(BoundSubqueryExpression &expr, uniqu // first we translate the QueryNode of the subquery into a logical plan auto sub_binder = Binder::CreateBinder(context, this); sub_binder->is_outside_flattened = false; - auto subquery_root = sub_binder->CreatePlan(*expr.subquery); + auto subquery_root = std::move(expr.subquery.plan); D_ASSERT(subquery_root); // now we actually flatten the subquery @@ -403,7 +405,7 @@ void Binder::PlanSubqueries(unique_ptr &expr_ptr, unique_ptr Binder::PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated, JoinType join_type, + CorrelatedColumns &correlated, JoinType join_type, unique_ptr condition) { // scan the right operator for correlated columns // correlated LATERAL JOIN diff --git a/src/duckdb/src/planner/binder/statement/bind_attach.cpp b/src/duckdb/src/planner/binder/statement/bind_attach.cpp index 0e8655d2f..6da075e25 100644 --- a/src/duckdb/src/planner/binder/statement/bind_attach.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_attach.cpp @@ -1,7 +1,6 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/attach_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/execution/expression_executor.hpp" @@ -29,7 +28,7 @@ BoundStatement Binder::Bind(AttachStatement &stmt) { result.plan = make_uniq(LogicalOperatorType::LOGICAL_ATTACH, std::move(stmt.info)); auto &properties = GetStatementProperties(); - 
properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_call.cpp b/src/duckdb/src/planner/binder/statement/bind_call.cpp index ba96927e8..a746e1689 100644 --- a/src/duckdb/src/planner/binder/statement/bind_call.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_call.cpp @@ -1,8 +1,6 @@ #include "duckdb/parser/statement/call_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" @@ -19,7 +17,7 @@ BoundStatement Binder::Bind(CallStatement &stmt) { auto result = Bind(select_statement); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_copy.cpp b/src/duckdb/src/planner/binder/statement/bind_copy.cpp index b7881a0a1..7a6944b4a 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/bind_helpers.hpp" #include "duckdb/common/filename_pattern.hpp" #include "duckdb/common/local_file_system.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/function/table/read_csv.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" @@ -36,7 +37,7 @@ void IsFormatExtensionKnown(const string &format) { // It's a match, we must throw throw CatalogException( "Copy Function with name \"%s\" is not in the catalog, but it exists in the %s extension.", format, - file_postfixes.extension); + std::string(file_postfixes.extension)); } } } @@ -115,7 +116,7 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct PreserveOrderType preserve_order = PreserveOrderType::AUTOMATIC; CopyFunctionReturnType return_type = CopyFunctionReturnType::CHANGED_ROWS; - CopyFunctionBindInput bind_input(*stmt.info); + CopyFunctionBindInput bind_input(*stmt.info, function.function_info); bind_input.file_extension = function.extension; @@ -251,7 +252,6 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct auto new_select_list = function.copy_to_select(input); if (!new_select_list.empty()) { - // We have a new select list, create a projection on top of the current plan auto projection = make_uniq(GenerateTableIndex(), std::move(new_select_list)); projection->children.push_back(std::move(select_node.plan)); @@ -423,7 +423,10 @@ vector BindCopyOption(ClientContext &context, TableFunctionBinder &option } } auto bound_expr = option_binder.Bind(expr); - auto val = ExpressionExecutor::EvaluateScalar(context, *bound_expr); + if (bound_expr->HasParameter()) { + throw ParameterNotResolvedException(); + } + auto val = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true); if (val.IsNull()) { throw BinderException("NULL is not supported as a valid option for COPY option \"" + name + "\""); } @@ -551,8 +554,8 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { // check if this matches the mode if (copy_option.mode != CopyOptionMode::READ_WRITE && copy_option.mode != 
copy_mode) { throw InvalidInputException("Option \"%s\" is not supported for %s - only for %s", provided_option, - stmt.info->is_from ? "reading" : "writing", - stmt.info->is_from ? "writing" : "reading"); + std::string(stmt.info->is_from ? "reading" : "writing"), + std::string(stmt.info->is_from ? "writing" : "reading")); } if (copy_option.type.id() != LogicalTypeId::ANY) { if (provided_entry.second.empty()) { @@ -599,7 +602,7 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { } auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; if (stmt.info->is_from) { return BindCopyFrom(stmt, function); diff --git a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp index d2c0a03fb..9c11df4e9 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp @@ -21,7 +21,6 @@ namespace duckdb { unique_ptr Binder::BindCopyDatabaseSchema(Catalog &from_database, const string &target_database_name) { - catalog_entry_vector_t catalog_entries; catalog_entries = PhysicalExport::GetNaiveExportOrder(context, from_database); @@ -125,9 +124,13 @@ BoundStatement Binder::Bind(CopyDatabaseStatement &stmt) { result.plan = std::move(plan); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; - properties.RegisterDBModify(target_catalog, context); + + DatabaseModificationType modification; + modification |= DatabaseModificationType::INSERT_DATA; + modification |= DatabaseModificationType::CREATE_CATALOG_ENTRY; + properties.RegisterDBModify(target_catalog, context, modification); return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index 76b43f60a..731d2fc93 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -39,7 +39,6 @@ #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/storage/storage_extension.hpp" #include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/type_visitor.hpp" @@ -120,11 +119,11 @@ void Binder::SearchSchema(CreateInfo &info) { if (!info.temporary) { // non-temporary create: not read only if (info.catalog == TEMP_CATALOG) { - throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } else { if (info.catalog != TEMP_CATALOG) { - throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); + throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", std::string(TEMP_CATALOG)); } } } @@ -137,7 +136,7 @@ SchemaCatalogEntry &Binder::BindSchema(CreateInfo &info) { info.schema = schema_obj.name; if (!info.temporary) { auto &properties = GetStatementProperties(); - properties.RegisterDBModify(schema_obj.catalog, context); 
+ properties.RegisterDBModify(schema_obj.catalog, context, DatabaseModificationType::CREATE_CATALOG_ENTRY); } return schema_obj; } @@ -221,13 +220,14 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { // Figure out if we can store typed macro parameters auto &attached = catalog.GetAttached(); - auto store_types = info.temporary || attached.IsTemporary(); + auto store_types = true; if (attached.HasStorageManager()) { + // If DuckDB is used as a storage, we must check the version. auto &storage_manager = attached.GetStorageManager(); const auto since = SerializationCompatibility::FromString("v1.4.0").serialization_version; - store_types |= storage_manager.InMemory() || storage_manager.GetStorageVersion() >= since; + store_types = info.temporary || attached.IsTemporary() || storage_manager.InMemory() || + storage_manager.GetStorageVersion() >= since; } - // try to bind each of the included functions vector_of_logical_type_set_t type_overloads; auto &base = info.Cast(); @@ -345,11 +345,7 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { try { dummy_binder->Bind(*query_node); } catch (const std::exception &ex) { - // TODO: we would like to do something like "error = ErrorData(ex);" here, - // but that breaks macro's like "create macro m(x) as table (from query_table(x));", - // because dummy-binding these always throws an error instead of a ParameterNotResolvedException. - // So, for now, we allow macro's with bind errors to be created. - // Binding is still useful because we can create the dependencies. + error = ErrorData(ex); } } @@ -509,7 +505,8 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { case CatalogType::SCHEMA_ENTRY: { auto &base = stmt.info->Cast(); auto catalog = BindCatalog(base.catalog); - properties.RegisterDBModify(Catalog::GetCatalog(context, catalog), context); + properties.RegisterDBModify(Catalog::GetCatalog(context, catalog), context, + DatabaseModificationType::CREATE_CATALOG_ENTRY); result.plan = make_uniq(LogicalOperatorType::LOGICAL_CREATE_SCHEMA, std::move(stmt.info)); break; } @@ -548,23 +545,21 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { create_index_info.table); auto table_ref = make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + auto plan = std::move(bound_table.plan); + if (plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("can only create an index on a base table"); + } + auto &get = plan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { throw BinderException("can only create an index on a base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto &table = *table_ptr; if (table.temporary) { stmt.info->temporary = true; } - properties.RegisterDBModify(table.catalog, context); - - // create a plan over the bound table - auto plan = CreatePlan(*bound_table); - if (plan->type != LogicalOperatorType::LOGICAL_GET) { - throw BinderException("Cannot create index on a view!"); - } - + properties.RegisterDBModify(table.catalog, context, DatabaseModificationType::CREATE_INDEX); result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); break; } @@ -718,7 +713,7 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { throw InternalException("Unrecognized type!"); } properties.return_type = StatementReturnType::NOTHING; - properties.allow_stream_result = false; + properties.output_type = 
QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp index ad70fe14a..ac8a32f65 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp @@ -13,6 +13,7 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/common/string.hpp" #include "duckdb/common/queue.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/parser/expression/list.hpp" #include "duckdb/common/index_map.hpp" #include "duckdb/planner/expression_iterator.hpp" @@ -40,10 +41,18 @@ static void VerifyCompressionType(ClientContext &context, optional_ptrCast(); for (auto &col : base.columns.Logical()) { auto compression_type = col.CompressionType(); - if (CompressionTypeIsDeprecated(compression_type, storage_manager)) { - throw BinderException("Can't compress using user-provided compression type '%s', that type is deprecated " - "and only has decompress support", - CompressionTypeToString(compression_type)); + auto compression_availability_result = CompressionTypeIsAvailable(compression_type, storage_manager); + if (!compression_availability_result.IsAvailable()) { + if (compression_availability_result.IsDeprecated()) { + throw BinderException( + "Can't compress using user-provided compression type '%s', that type is deprecated " + "and only has decompress support", + CompressionTypeToString(compression_type)); + } else { + throw BinderException( + "Can't compress using user-provided compression type '%s', that type is not available yet", + CompressionTypeToString(compression_type)); + } } auto logical_type = col.GetType(); if (logical_type.id() == LogicalTypeId::USER && logical_type.HasAlias()) { @@ -289,7 +298,7 @@ void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { col.SetType(bound_expression->return_type); // Update the type in the binding, for future expansions - table_binding->types[i.index] = col.Type(); + table_binding->SetColumnType(i.index, col.Type()); } bound_indices.insert(i); } @@ -673,7 +682,7 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptrdependencies.VerifyDependencies(schema.catalog, result->Base().table); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_delete.cpp b/src/duckdb/src/planner/binder/statement/bind_delete.cpp index e83a62ae3..8820030d8 100644 --- a/src/duckdb/src/planner/binder/statement/bind_delete.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_delete.cpp @@ -5,8 +5,6 @@ #include "duckdb/planner/operator/logical_delete.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -15,38 +13,34 @@ namespace duckdb { BoundStatement Binder::Bind(DeleteStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only delete from base table!"); + auto root = std::move(bound_table.plan); + if (root->type != LogicalOperatorType::LOGICAL_GET) { + 
throw BinderException("Can only delete from base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - auto root = CreatePlan(*bound_table); auto &get = root->Cast(); - D_ASSERT(root->type == LogicalOperatorType::LOGICAL_GET); - + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only delete from base table"); + } + auto &table = *table_ptr; if (!table.temporary) { // delete from persistent table: not read only! auto &properties = GetStatementProperties(); - properties.RegisterDBModify(table.catalog, context); + properties.RegisterDBModify(table.catalog, context, DatabaseModificationType::DELETE_DATA); } - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - // plan any tables from the various using clauses if (!stmt.using_clauses.empty()) { unique_ptr child_operator; for (auto &using_clause : stmt.using_clauses) { // bind the using clause auto using_binder = Binder::CreateBinder(context, this); - auto bound_node = using_binder->Bind(*using_clause); - auto op = CreatePlan(*bound_node); + auto op = using_binder->Bind(*using_clause); if (child_operator) { // already bound a child: create a cross product to unify the two - child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op)); + child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op.plan)); } else { - child_operator = std::move(op); + child_operator = std::move(op.plan); } bind_context.AddContext(std::move(using_binder->bind_context)); } @@ -90,7 +84,7 @@ BoundStatement Binder::Bind(DeleteStatement &stmt) { result.types = {LogicalType::BIGINT}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; diff --git a/src/duckdb/src/planner/binder/statement/bind_detach.cpp b/src/duckdb/src/planner/binder/statement/bind_detach.cpp index 98db58055..b2d3313f5 100644 --- a/src/duckdb/src/planner/binder/statement/bind_detach.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_detach.cpp @@ -13,7 +13,7 @@ BoundStatement Binder::Bind(DetachStatement &stmt) { result.types = {LogicalType::BOOLEAN}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_drop.cpp b/src/duckdb/src/planner/binder/statement/bind_drop.cpp index f40a86c61..a3d56bb2d 100644 --- a/src/duckdb/src/planner/binder/statement/bind_drop.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_drop.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/statement/drop_statement.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/standard_entry.hpp" @@ -24,7 +23,7 @@ BoundStatement Binder::Bind(DropStatement &stmt) { case CatalogType::SCHEMA_ENTRY: { // dropping a schema is never read-only because there are no temporary schemas auto &catalog = Catalog::GetCatalog(context, stmt.info->catalog); - properties.RegisterDBModify(catalog, context); + properties.RegisterDBModify(catalog, context, DatabaseModificationType::DROP_CATALOG_ENTRY); break; } case CatalogType::VIEW_ENTRY: @@ -77,7 +76,7 @@ BoundStatement 
Binder::Bind(DropStatement &stmt) { stmt.info->catalog = entry->ParentCatalog().GetName(); if (!entry->temporary) { // we can only drop temporary schema entries in read-only mode - properties.RegisterDBModify(entry->ParentCatalog(), context); + properties.RegisterDBModify(entry->ParentCatalog(), context, DatabaseModificationType::DROP_CATALOG_ENTRY); } stmt.info->schema = entry->ParentSchema().name; break; @@ -94,7 +93,7 @@ BoundStatement Binder::Bind(DropStatement &stmt) { result.names = {"Success"}; result.types = {LogicalType::BOOLEAN}; - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp index cceb6796c..1202b01fa 100644 --- a/src/duckdb/src/planner/binder/statement/bind_execute.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -79,7 +79,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy()); rebound_plan = std::move(prepared_planner.plan); D_ASSERT(prepared->properties.bound_all_parameters); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; } // copy the properties of the prepared statement into the planner auto &properties = GetStatementProperties(); diff --git a/src/duckdb/src/planner/binder/statement/bind_export.cpp b/src/duckdb/src/planner/binder/statement/bind_export.cpp index 20d2606fe..0e6f63020 100644 --- a/src/duckdb/src/planner/binder/statement/bind_export.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_export.cpp @@ -302,7 +302,7 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { result.plan = std::move(export_node); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_extension.cpp b/src/duckdb/src/planner/binder/statement/bind_extension.cpp index b4fc0e86b..6569315f7 100644 --- a/src/duckdb/src/planner/binder/statement/bind_extension.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_extension.cpp @@ -5,8 +5,6 @@ namespace duckdb { BoundStatement Binder::Bind(ExtensionStatement &stmt) { - BoundStatement result; - // perform the planning of the function D_ASSERT(stmt.extension.plan_function); auto parse_result = @@ -18,11 +16,9 @@ BoundStatement Binder::Bind(ExtensionStatement &stmt) { properties.return_type = parse_result.return_type; // create the plan as a scan of the given table function - result.plan = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); + auto result = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); D_ASSERT(result.plan->type == LogicalOperatorType::LOGICAL_GET); auto &get = result.plan->Cast(); - result.names = get.names; - result.types = get.returned_types; get.ClearColumnIds(); for (idx_t i = 0; i < get.returned_types.size(); i++) { get.AddColumnId(i); diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index f2c8db644..97b25f89e 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp 
+++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -22,9 +22,6 @@ #include "duckdb/planner/expression/bound_default_expression.hpp" #include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/parser/tableref/basetableref.hpp" @@ -99,7 +96,6 @@ void DoUpdateSetQualify(unique_ptr &expr, const string &table_ void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, vector> &lambda_params) { - for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { DoUpdateSetQualify(child, table_name, lambda_params); @@ -141,7 +137,6 @@ void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &tabl void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, vector> &lambda_params) { - // We avoid ambiguity with EXCLUDED columns by qualifying all column references. switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { @@ -277,7 +272,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, auto storage_info = table.GetStorageInfo(context); auto &columns = table.GetColumns(); // set up the columns on which to join - vector distinct_on_columns; + vector> all_distinct_on_columns; if (on_conflict_info.indexed_columns.empty()) { // When omitting the conflict target, we derive the join columns from the primary key/unique constraints // traverse the primary key/unique constraints @@ -292,6 +287,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, vector> and_children; auto &indexed_columns = index.column_set; + vector distinct_on_columns; for (auto &column : columns.Physical()) { if (!indexed_columns.count(column.Physical().index)) { continue; @@ -303,6 +299,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, and_children.push_back(std::move(new_condition)); distinct_on_columns.push_back(column.Name()); } + all_distinct_on_columns.push_back(std::move(distinct_on_columns)); if (and_children.empty()) { continue; } @@ -377,7 +374,7 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, throw BinderException("The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY " "CONSTRAINT or INDEX"); } - distinct_on_columns = on_conflict_info.indexed_columns; + all_distinct_on_columns.push_back(on_conflict_info.indexed_columns); merge_into->using_columns = std::move(on_conflict_info.indexed_columns); } @@ -445,23 +442,29 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, } } // push DISTINCT ON(unique_columns) - auto distinct_stmt = make_uniq(); - auto select_node = make_uniq(); - auto distinct = make_uniq(); - for (auto &col : distinct_on_columns) { - distinct->distinct_on_targets.push_back(make_uniq(col)); + for (auto &distinct_on_columns : all_distinct_on_columns) { + auto distinct_stmt = make_uniq(); + auto select_node = make_uniq(); + auto distinct = make_uniq(); + for (auto &col : distinct_on_columns) { + distinct->distinct_on_targets.push_back(make_uniq(col)); + } + select_node->modifiers.push_back(std::move(distinct)); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = std::move(source); + distinct_stmt->node = std::move(select_node); + 
source = make_uniq(std::move(distinct_stmt), "excluded"); } - select_node->modifiers.push_back(std::move(distinct)); - select_node->select_list.push_back(make_uniq()); - select_node->from_table = std::move(source); - distinct_stmt->node = std::move(select_node); - source = make_uniq(std::move(distinct_stmt), "excluded"); merge_into->source = std::move(source); if (on_conflict_info.action_type == OnConflictAction::REPLACE) { D_ASSERT(!on_conflict_info.set_info); - on_conflict_info.set_info = CreateSetInfoForReplace(table, stmt, storage_info); + // For BY POSITION, create explicit SET information + // For BY NAME, leave it empty and let bind_merge_into handle it automatically + if (stmt.column_order != InsertColumnOrder::INSERT_BY_NAME) { + on_conflict_info.set_info = CreateSetInfoForReplace(table, stmt, storage_info); + } on_conflict_info.action_type = OnConflictAction::UPDATE; } // now set up the merge actions @@ -480,16 +483,19 @@ unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, // when doing UPDATE set up the when matched action auto update_action = make_uniq(); update_action->action_type = MergeActionType::MERGE_UPDATE; - for (auto &col : on_conflict_info.set_info->expressions) { - vector> lambda_params; - DoUpdateSetQualify(col, table_name, lambda_params); - } - if (on_conflict_info.set_info->condition) { - vector> lambda_params; - DoUpdateSetQualify(on_conflict_info.set_info->condition, table_name, lambda_params); - update_action->condition = std::move(on_conflict_info.set_info->condition); + update_action->column_order = stmt.column_order; + if (on_conflict_info.set_info) { + for (auto &col : on_conflict_info.set_info->expressions) { + vector> lambda_params; + DoUpdateSetQualify(col, table_name, lambda_params); + } + if (on_conflict_info.set_info->condition) { + vector> lambda_params; + DoUpdateSetQualify(on_conflict_info.set_info->condition, table_name, lambda_params); + update_action->condition = std::move(on_conflict_info.set_info->condition); + } + update_action->update_info = std::move(on_conflict_info.set_info); } - update_action->update_info = std::move(on_conflict_info.set_info); merge_into->actions[MergeActionCondition::WHEN_MATCHED].push_back(std::move(update_action)); } @@ -515,12 +521,15 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { if (!table.temporary) { // inserting into a non-temporary table: alters underlying database auto &properties = GetStatementProperties(); - properties.RegisterDBModify(table.catalog, context); + DatabaseModificationType modification_type = DatabaseModificationType::INSERT_DATA; + auto storage_info = table.GetStorageInfo(context); + if (!storage_info.index_info.empty()) { + modification_type = DatabaseModificationType::INSERT_DATA_WITH_INDEX; + } + properties.RegisterDBModify(table.catalog, context, modification_type); } auto insert = make_uniq(table, GenerateTableIndex()); - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); auto values_list = stmt.GetValuesList(); @@ -593,7 +602,7 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { result.plan = std::move(insert); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_load.cpp b/src/duckdb/src/planner/binder/statement/bind_load.cpp index 53d8f5792..a252716fe 100644 --- a/src/duckdb/src/planner/binder/statement/bind_load.cpp +++ 
b/src/duckdb/src/planner/binder/statement/bind_load.cpp @@ -24,7 +24,7 @@ BoundStatement Binder::Bind(LoadStatement &stmt) { result.plan = make_uniq(LogicalOperatorType::LOGICAL_LOAD, std::move(stmt.info)); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::NOTHING; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp index 5b187c8e3..1fc7a188f 100644 --- a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp @@ -26,13 +26,13 @@ BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { result.plan = std::move(stmt.plan); auto &properties = GetStatementProperties(); - properties.allow_stream_result = true; + properties.output_type = QueryResultOutputType::ALLOW_STREAMING; properties.return_type = StatementReturnType::QUERY_RESULT; // TODO could also be something else if (parent) { throw InternalException("LogicalPlanStatement should be bound in root binder"); } - bound_tables = GetMaxTableIndex(*result.plan) + 1; + global_binder_state->bound_tables = GetMaxTableIndex(*result.plan) + 1; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index 87a9726ec..dcc37e017 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -1,6 +1,5 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" @@ -41,10 +40,20 @@ unique_ptr Binder::BindMergeAction(LogicalMergeInto &merge auto result = make_uniq(); result->action_type = action.action_type; if (action.condition) { - ProjectionBinder proj_binder(*this, context, proj_index, expressions, "WHERE clause"); - proj_binder.target_type = LogicalType::BOOLEAN; - auto cond = proj_binder.Bind(action.condition); - result->condition = std::move(cond); + if (action.condition->HasSubquery()) { + // if we have a subquery we need to execute the condition outside of the MERGE INTO statement + WhereBinder where_binder(*this, context); + auto cond = where_binder.Bind(action.condition); + PlanSubqueries(cond, root); + result->condition = + make_uniq(cond->return_type, ColumnBinding(proj_index, expressions.size())); + expressions.push_back(std::move(cond)); + } else { + ProjectionBinder proj_binder(*this, context, proj_index, expressions, "WHERE clause"); + proj_binder.target_type = LogicalType::BOOLEAN; + auto cond = proj_binder.Bind(action.condition); + result->condition = std::move(cond); + } } switch (action.action_type) { case MergeActionType::MERGE_UPDATE: { @@ -173,20 +182,68 @@ void RewriteMergeBindings(LogicalOperator &op, const vector &sour op, [&](unique_ptr *child) { RewriteMergeBindings(*child, source_bindings, new_table_index); }); } +LogicalGet &ExtractLogicalGet(LogicalOperator &op) { + reference current_op(op); + while (current_op.get().type == LogicalOperatorType::LOGICAL_FILTER) { + current_op = *current_op.get().children[0]; + } + if (current_op.get().type != 
LogicalOperatorType::LOGICAL_GET) { + throw InvalidInputException("BindMerge - expected to find an operator of type LOGICAL_GET but got %s", + op.ToString()); + } + return current_op.get().Cast(); +} + +void CheckMergeAction(MergeActionCondition condition, MergeActionType action_type) { + if (condition == MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET) { + switch (action_type) { + case MergeActionType::MERGE_UPDATE: + case MergeActionType::MERGE_DELETE: + throw ParserException("WHEN NOT MATCHED (BY TARGET) cannot be combined with UPDATE or DELETE actions - as " + "there is no corresponding row in the target to update or delete.\nDid you mean to " + "use WHEN MATCHED or WHEN NOT MATCHED BY SOURCE?"); + default: + break; + } + } +} + BoundStatement Binder::Bind(MergeIntoStatement &stmt) { // bind the target table auto target_binder = Binder::CreateBinder(context, this); string table_alias = stmt.target->alias; auto bound_table = target_binder->Bind(*stmt.target); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only merge into base tables!"); + } + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { throw BinderException("Can only merge into base tables!"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto &table = *table_ptr; if (!table.temporary) { // update of persistent table: not read only! auto &properties = GetStatementProperties(); - properties.RegisterDBModify(table.catalog, context); + // modification type depends on actions + DatabaseModificationType modification; + for (auto &action_condition : stmt.actions) { + for (auto &action : action_condition.second) { + switch (action->action_type) { + case MergeActionType::MERGE_UPDATE: + modification |= DatabaseModificationType::UPDATE_DATA; + break; + case MergeActionType::MERGE_DELETE: + modification |= DatabaseModificationType::DELETE_DATA; + break; + case MergeActionType::MERGE_INSERT: + modification |= DatabaseModificationType::INSERT_DATA; + break; + default: + break; + } + } + } + properties.RegisterDBModify(table.catalog, context, modification); } // bind the source @@ -198,9 +255,10 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { vector source_names; for (auto &binding_entry : source_binder->bind_context.GetBindingsList()) { auto &binding = *binding_entry; - for (idx_t c = 0; c < binding.names.size(); c++) { - source_aliases.push_back(binding.alias); - source_names.push_back(binding.names[c]); + auto &column_names = binding.GetColumnNames(); + for (idx_t c = 0; c < column_names.size(); c++) { + source_aliases.push_back(binding.GetBindingAlias()); + source_names.push_back(column_names[c]); } } @@ -231,11 +289,19 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { } auto bound_join_node = Bind(join); - auto root = CreatePlan(*bound_join_node); + auto root = std::move(bound_join_node.plan); + auto join_ref = reference(*root); + while (join_ref.get().children.size() == 1) { + join_ref = *join_ref.get().children[0]; + } + if (join_ref.get().children.size() != 2) { + throw NotImplementedException("Expected a join after binding a join operator - but got a %s", + join_ref.get().type); + } // kind of hacky, CreatePlan turns a RIGHT join into a LEFT join so the children get reversed from what we need bool inverted = join.type == JoinType::RIGHT; - auto &source = root->children[inverted ? 1 : 0]; - auto &get = root->children[inverted ? 
0 : 1]->Cast(); + auto &source = join_ref.get().children[inverted ? 1 : 0]; + auto &get = ExtractLogicalGet(*join_ref.get().children[inverted ? 0 : 1]); auto merge_into = make_uniq(table); merge_into->table_index = GenerateTableIndex(); @@ -257,6 +323,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { for (auto &entry : stmt.actions) { vector> bound_actions; for (auto &action : entry.second) { + CheckMergeAction(entry.first, action->action_type); bound_actions.push_back(BindMergeAction(*merge_into, table, get, proj_index, projection_expressions, root, *action, source_aliases, source_names)); } @@ -327,7 +394,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { result.types = {LogicalType::BIGINT}; auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp index 3955cf897..b5fc04677 100644 --- a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/planner/operator/logical_pragma.hpp" #include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" @@ -28,16 +29,32 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte } // bind the pragma function - auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); + auto entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + // try to find whether a table entry might exist + auto table_entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, + info.name, OnEntryNotFound::RETURN_NULL); + if (table_entry) { + // there is a table entry with the same name, now throw more explicit error message + throw CatalogException("Pragma Function with name %s does not exist, but a table function with the same " + "name exists, try `CALL %s(...)`", + info.name, info.name); + } + // rebind to throw exception + entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::THROW_EXCEPTION); + } + FunctionBinder function_binder(*this); ErrorData error; - auto bound_idx = function_binder.BindFunction(entry.name, entry.functions, params, error); + auto bound_idx = function_binder.BindFunction(entry->name, entry->functions, params, error); if (!bound_idx.IsValid()) { D_ASSERT(error.HasError()); error.AddQueryLocation(error_context); error.Throw(); } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx.GetIndex()); + auto bound_function = entry->functions.GetFunctionByOffset(bound_idx.GetIndex()); // bind and check named params BindNamedParameters(bound_function.named_parameters, named_parameters, error_context, bound_function.name); return make_uniq(std::move(bound_function), std::move(params), std::move(named_parameters)); diff --git a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp index cbb338dfc..4c0579726 100644 ---
a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp @@ -8,7 +8,7 @@ namespace duckdb { BoundStatement Binder::Bind(PrepareStatement &stmt) { Planner prepared_planner(context); auto prepared_data = prepared_planner.PrepareSQLStatement(std::move(stmt.statement)); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; if (prepared_planner.properties.always_require_rebind) { // we always need to rebind - don't keep the plan around @@ -20,7 +20,7 @@ BoundStatement Binder::Bind(PrepareStatement &stmt) { // this is required because most clients ALWAYS invoke prepared statements auto &properties = GetStatementProperties(); properties.requires_valid_transaction = false; - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.bound_all_parameters = true; properties.parameter_count = 0; properties.return_type = StatementReturnType::NOTHING; diff --git a/src/duckdb/src/planner/binder/statement/bind_select.cpp b/src/duckdb/src/planner/binder/statement/bind_select.cpp index ee68d0e25..a2656d076 100644 --- a/src/duckdb/src/planner/binder/statement/bind_select.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_select.cpp @@ -6,7 +6,7 @@ namespace duckdb { BoundStatement Binder::Bind(SelectStatement &stmt) { auto &properties = GetStatementProperties(); - properties.allow_stream_result = true; + properties.output_type = QueryResultOutputType::ALLOW_STREAMING; properties.return_type = StatementReturnType::QUERY_RESULT; return Bind(*stmt.node); } diff --git a/src/duckdb/src/planner/binder/statement/bind_simple.cpp b/src/duckdb/src/planner/binder/statement/bind_simple.cpp index 942f6784c..2baf55e73 100644 --- a/src/duckdb/src/planner/binder/statement/bind_simple.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_simple.cpp @@ -60,16 +60,15 @@ BoundStatement Binder::BindAlterAddIndex(BoundStatement &result, CatalogEntry &e TableDescription table_description(table_info.catalog, table_info.schema, table_info.name); auto table_ref = make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("can only add an index to a base table"); } - auto plan = CreatePlan(*bound_table); - auto &get = plan->Cast(); + auto &get = bound_table.plan->Cast(); get.names = column_list.GetColumnNames(); auto alter_table_info = unique_ptr_cast(std::move(alter_info)); - result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(plan), std::move(create_index_info), - std::move(alter_table_info)); + result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(bound_table.plan), + std::move(create_index_info), std::move(alter_table_info)); return std::move(result); } @@ -82,7 +81,7 @@ BoundStatement Binder::Bind(AlterStatement &stmt) { if (stmt.info->type == AlterType::ALTER_DATABASE) { auto &properties = GetStatementProperties(); properties.return_type = StatementReturnType::NOTHING; - properties.RegisterDBModify(Catalog::GetSystemCatalog(context), context); + properties.RegisterDBModify(Catalog::GetSystemCatalog(context), context, DatabaseModificationType::ALTER_TABLE); result.plan = make_uniq(LogicalOperatorType::LOGICAL_ALTER, std::move(stmt.info)); return result; } @@ -115,7 +114,7 @@ BoundStatement 
Binder::Bind(AlterStatement &stmt) { } if (!entry->temporary) { // We can only alter temporary tables and views in read-only mode. - properties.RegisterDBModify(catalog, context); + properties.RegisterDBModify(catalog, context, DatabaseModificationType::ALTER_TABLE); } stmt.info->catalog = catalog.GetName(); stmt.info->schema = entry->ParentSchema().name; diff --git a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp index 45b2b2f25..f8a68ae4c 100644 --- a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp @@ -9,7 +9,6 @@ #include "duckdb/parser/tableref/showref.hpp" #include "duckdb/parser/tableref/basetableref.hpp" #include "duckdb/parser/expression/star_expression.hpp" -#include "duckdb/planner/bound_tableref.hpp" namespace duckdb { @@ -78,7 +77,7 @@ static unique_ptr SummarizeCreateNullPercentage(string column_ return make_uniq(LogicalType::DECIMAL(9, 2), std::move(case_expr)); } -unique_ptr Binder::BindSummarize(ShowRef &ref) { +BoundStatement Binder::BindSummarize(ShowRef &ref) { unique_ptr query; if (ref.query) { query = std::move(ref.query); diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 650b23b89..1480000c6 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -2,7 +2,6 @@ #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/constraints/bound_check_constraint.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_default_expression.hpp" @@ -12,7 +11,6 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/storage/data_table.hpp" @@ -110,14 +108,15 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only update base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); + auto &bound_table_get = bound_table.plan->Cast(); + auto table_ptr = bound_table_get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only update base table"); + } + auto &table = *table_ptr; optional_ptr get; if (stmt.from_table) { @@ -129,14 +128,14 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { get = &root->children[0]->Cast(); bind_context.AddContext(std::move(from_binder->bind_context)); } else { - root = CreatePlan(*bound_table); + root = std::move(bound_table.plan); get = &root->Cast(); } if (!table.temporary) { // update of persistent table: not read only! 
auto &properties = GetStatementProperties(); - properties.RegisterDBModify(table.catalog, context); + properties.RegisterDBModify(table.catalog, context, DatabaseModificationType::UPDATE_DATA); } auto update = make_uniq(table); @@ -192,7 +191,7 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { result.plan = std::move(update); auto &properties = GetStatementProperties(); - properties.allow_stream_result = false; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; properties.return_type = StatementReturnType::CHANGED_ROWS; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp index 93e70fe5b..026f682b0 100644 --- a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp @@ -15,12 +15,18 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } D_ASSERT(vacuum.column_id_map.empty()); + auto bound_table = Bind(*info.ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw InvalidInputException("can only vacuum or analyze base tables"); + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only vacuum or analyze base tables"); + } + auto table_scan = std::move(bound_table.plan); + auto &get = table_scan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only vacuum or analyze base tables"); } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; vacuum.SetTable(table); vector> select_list; @@ -60,11 +66,6 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } info.columns = std::move(non_generated_column_names); - auto table_scan = CreatePlan(*ref); - D_ASSERT(table_scan->type == LogicalOperatorType::LOGICAL_GET); - - auto &get = table_scan->Cast(); - auto &column_ids = get.GetColumnIds(); D_ASSERT(select_list.size() == column_ids.size()); D_ASSERT(info.columns.size() == column_ids.size()); diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index da1dacb15..ac19b0a81 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -11,15 +11,13 @@ #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/planner/tableref/bound_at_clause.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { @@ -48,10 +46,10 @@ static bool TryLoadExtensionForReplacementScan(ClientContext &context, const str return false; } -unique_ptr Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { +BoundStatement Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { auto &config = DBConfig::GetConfig(context); if 
(!context.config.use_replacement_scans) { - return nullptr; + return BoundStatement(); } for (auto &scan : config.replacement_scans) { ReplacementScanInput input(ref.catalog_name, ref.schema_name, ref.table_name); @@ -73,14 +71,21 @@ unique_ptr Binder::BindWithReplacementScan(ClientContext &context auto &subquery = replacement_function->Cast(); subquery.column_name_alias = ref.column_name_alias; } else { - throw InternalException("Replacement scan should return either a table function or a subquery"); + auto select_node = make_uniq(); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = std::move(replacement_function); + auto select_stmt = make_uniq(); + select_stmt->node = std::move(select_node); + auto subquery = make_uniq(std::move(select_stmt)); + subquery->column_name_alias = ref.column_name_alias; + replacement_function = std::move(subquery); } if (GetBindingMode() == BindingMode::EXTRACT_REPLACEMENT_SCANS) { AddReplacementScan(ref.table_name, replacement_function->Copy()); } return Bind(*replacement_function); } - return nullptr; + return BoundStatement(); } unique_ptr Binder::BindAtClause(optional_ptr at_clause) { @@ -116,62 +121,36 @@ static vector ExchangeAllNullTypes(const vector &types return result; } -unique_ptr Binder::Bind(BaseTableRef &ref) { +BoundStatement Binder::Bind(BaseTableRef &ref) { QueryErrorContext error_context(ref.query_location); // CTEs and views are also referred to using BaseTableRefs, hence need to distinguish here // check if the table name refers to a CTE // CTE name should never be qualified (i.e. schema_name should be empty) // unless we want to refer to the recurring table of "using key". - vector> found_ctes; - if (ref.schema_name.empty() || ref.schema_name == "recurring") { - found_ctes = FindCTE(ref.table_name, false); - } - - if (!found_ctes.empty()) { - // Check if there is a CTE binding in the BindContext - auto ctebinding = bind_context.GetCTEBinding(ref.table_name); - if (ctebinding) { - // There is a CTE binding in the BindContext. - // This can only be the case if there is a recursive CTE, - // or a materialized CTE present. - auto index = GenerateTableIndex(); - - if (ref.schema_name == "recurring") { - auto recurring_bindings = FindCTE("recurring." + ref.table_name, false); - if (recurring_bindings.empty()) { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query." - " Hint: RECURRING can only be used with USING KEY in recursive CTE.", - ref.table_name); - } - } - - auto result = make_uniq(index, ctebinding->index, ref.schema_name == "recurring"); - auto alias = ref.alias.empty() ? ref.table_name : ref.alias; - auto names = BindContext::AliasColumnNames(alias, ctebinding->names, ref.column_name_alias); + BindingAlias binding_alias(ref.schema_name, ref.table_name); + auto ctebinding = GetCTEBinding(binding_alias); + if (ctebinding && ctebinding->CanBeReferenced()) { + ctebinding->Reference(); - bind_context.AddGenericBinding(index, alias, names, ctebinding->types); + // There is a CTE binding in the BindContext. + // This can only be the case if there is a recursive CTE, + // or a materialized CTE present. + auto index = GenerateTableIndex(); - auto cte_reference = ref.schema_name.empty() ? ref.table_name : ref.schema_name + "." + ref.table_name; + auto alias = ref.alias.empty() ? 
ref.table_name : ref.alias; + auto names = BindContext::AliasColumnNames(alias, ctebinding->GetColumnNames(), ref.column_name_alias); - // Update references to CTE - auto cteref = bind_context.cte_references[cte_reference]; - - if (cteref == nullptr && ref.schema_name == "recurring") { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query.", - ref.table_name); - } + bind_context.AddGenericBinding(index, alias, names, ctebinding->GetColumnTypes()); - (*cteref)++; + bool is_recurring = ref.schema_name == "recurring"; - result->types = ctebinding->types; - result->bound_columns = std::move(names); - return std::move(result); - } + BoundStatement result; + result.types = ctebinding->GetColumnTypes(); + result.names = names; + result.plan = + make_uniq(index, ctebinding->GetIndex(), result.types, std::move(names), is_recurring); + return result; } // not a CTE @@ -198,14 +177,19 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { vector types {LogicalType::INTEGER}; vector names {"__dummy_col" + to_string(table_index)}; bind_context.AddGenericBinding(table_index, ref_alias, names, types); - return make_uniq_base(table_index); + + BoundStatement result; + result.types = std::move(types); + result.names = std::move(names); + result.plan = make_uniq(table_index); + return result; } } if (!table_or_view) { // table could not be found: try to bind a replacement scan // Try replacement scan bind auto replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } @@ -214,7 +198,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { auto extension_loaded = TryLoadExtensionForReplacementScan(context, full_path); if (extension_loaded) { replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } } @@ -222,7 +206,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { if (context.config.use_replacement_scans && config.options.enable_external_access && ExtensionHelper::IsFullPath(full_path)) { auto &fs = FileSystem::GetFileSystem(context); - if (fs.FileExists(full_path)) { + if (!fs.IsDisabledForPath(full_path) && fs.FileExists(full_path)) { throw BinderException( "No extension found that is capable of reading the file \"%s\"\n* If this file is a supported file " "format you can explicitly use the reader functions, such as read_csv, read_json or read_parquet", @@ -230,17 +214,13 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { } } - // remember that we did not find a CTE, but there is a CTE with the same name - // this means that there is a circular reference - // Otherwise, re-throw the original exception - if (found_ctes.empty() && ref.schema_name.empty() && CTEExists(ref.table_name)) { - throw BinderException( - error_context, - "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " - "use recursive CTEs. \n2. If " - "you want to use the TABLE name \"%s\" the same as the CTE name, please explicitly add " - "\"SCHEMA\" before table name. 
You can try \"main.%s\" (main is the duckdb default schema)", - ref.table_name, ref.table_name, ref.table_name); + // if we found a CTE that cannot be referenced that means that there is a circular reference + if (ctebinding) { + D_ASSERT(!ctebinding->CanBeReferenced()); + throw BinderException(error_context, + "Circular reference to CTE \"%s\", use WITH RECURSIVE to " + "use recursive CTEs.", + ref.table_name); } // could not find an alternative: bind again to get the error // note: this will always throw when using DuckDB as a catalog, but a second look-up might succeed @@ -251,7 +231,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { switch (table_or_view->type) { case CatalogType::TABLE_ENTRY: { - // base table: create the BoundBaseTableRef node + // base table auto table_index = GenerateTableIndex(); auto &table = table_or_view->Cast(); @@ -294,7 +274,11 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { } else { bind_context.AddBaseTable(table_index, ref.alias, table_names, table_types, col_ids, *table_entry); } - return make_uniq_base(table, std::move(logical_get)); + BoundStatement result; + result.types = table_types; + result.names = table_names; + result.plan = std::move(logical_get); + return result; } case CatalogType::VIEW_ENTRY: { // the node is a view: get the query that the view represents @@ -307,29 +291,6 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { // The view may contain CTEs, but maybe only in the cte_map, so we need create CTE nodes for them auto query = view_catalog_entry.GetQuery().Copy(); - auto &select_stmt = query->Cast(); - - vector> materialized_ctes; - for (auto &cte : select_stmt.node->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt.node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - select_stmt.node = std::move(root); - SubqueryRef subquery(unique_ptr_cast(std::move(query))); subquery.alias = ref.alias; @@ -355,15 +316,13 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { throw BinderException("Contents of view were altered - view bound correlated columns"); } - D_ASSERT(bound_child->type == TableReferenceType::SUBQUERY); // verify that the types and names match up with the expected types and names if the view has type info defined - auto &bound_subquery = bound_child->Cast(); if (GetBindingMode() != BindingMode::EXTRACT_NAMES && GetBindingMode() != BindingMode::EXTRACT_QUALIFIED_NAMES && view_catalog_entry.HasTypes()) { // we bind the view subquery and the original view with different "can_contain_nulls", // but we don't want to throw an error when SQLNULL does not match up with INTEGER, // so we exchange all SQLNULL with INTEGER here before comparing - auto bound_types = ExchangeAllNullTypes(bound_subquery.subquery->types); + auto bound_types = ExchangeAllNullTypes(bound_child.types); auto view_types = ExchangeAllNullTypes(view_catalog_entry.types); if (bound_types != view_types) { auto actual_types = StringUtil::ToString(bound_types, ", "); @@ -372,17 +331,17 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { "Contents of view were altered: types don't match! 
Expected [%s], but found [%s] instead", expected_types, actual_types); } - if (bound_subquery.subquery->names.size() == view_catalog_entry.names.size() && - bound_subquery.subquery->names != view_catalog_entry.names) { - auto actual_names = StringUtil::Join(bound_subquery.subquery->names, ", "); + if (bound_child.names.size() == view_catalog_entry.names.size() && + bound_child.names != view_catalog_entry.names) { + auto actual_names = StringUtil::Join(bound_child.names, ", "); auto expected_names = StringUtil::Join(view_catalog_entry.names, ", "); throw BinderException( "Contents of view were altered: names don't match! Expected [%s], but found [%s] instead", expected_names, actual_names); } } - bind_context.AddView(bound_subquery.subquery->GetRootIndex(), subquery.alias, subquery, - *bound_subquery.subquery, view_catalog_entry); + bind_context.AddView(bound_child.plan->GetRootIndex(), subquery.alias, subquery, bound_child, + view_catalog_entry); return bound_child; } default: diff --git a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp index e31c2e83c..ace531ccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp @@ -2,8 +2,8 @@ namespace duckdb { -unique_ptr Binder::Bind(BoundRefWrapper &ref) { - if (!ref.binder || !ref.bound_ref) { +BoundStatement Binder::Bind(BoundRefWrapper &ref) { + if (!ref.binder || !ref.bound_ref.plan) { throw InternalException("Rebinding bound ref that was already bound"); } bind_context.AddContext(std::move(ref.binder->bind_context)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp index 635d23f71..d3c5ea4a2 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp @@ -1,20 +1,25 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/tableref/column_data_ref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(ColumnDataRef &ref) { +BoundStatement Binder::Bind(ColumnDataRef &ref) { auto &collection = *ref.collection; auto types = collection.Types(); - auto result = make_uniq(std::move(ref.collection)); - result->bind_index = GenerateTableIndex(); - for (idx_t i = ref.expected_names.size(); i < types.size(); i++) { - ref.expected_names.push_back("col" + to_string(i + 1)); + + BoundStatement result; + result.names = std::move(ref.expected_names); + for (idx_t i = result.names.size(); i < types.size(); i++) { + result.names.push_back("col" + to_string(i + 1)); } - bind_context.AddGenericBinding(result->bind_index, ref.alias, ref.expected_names, types); - return unique_ptr_cast(std::move(result)); + result.types = types; + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, ref.alias, result.names, types); + + result.plan = + make_uniq_base(bind_index, std::move(types), std::move(ref.collection)); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp index f280404f9..18c27cccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp @@ -1,16 +1,21 @@ 
#include "duckdb/parser/tableref/delimgetref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(DelimGetRef &ref) { +BoundStatement Binder::Bind(DelimGetRef &ref) { // Have to add bindings idx_t tbl_idx = GenerateTableIndex(); string internal_name = "__internal_delim_get_ref_" + std::to_string(tbl_idx); - bind_context.AddGenericBinding(tbl_idx, internal_name, ref.internal_aliases, ref.types); - return make_uniq(tbl_idx, ref.types); + BoundStatement result; + result.types = std::move(ref.types); + result.names = std::move(ref.internal_aliases); + result.plan = make_uniq(tbl_idx, result.types); + + bind_context.AddGenericBinding(tbl_idx, internal_name, result.names, result.types); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp index fe0e96f3d..b6ea93ab8 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp @@ -1,11 +1,13 @@ #include "duckdb/parser/tableref/emptytableref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(EmptyTableRef &ref) { - return make_uniq(GenerateTableIndex()); +BoundStatement Binder::Bind(EmptyTableRef &ref) { + BoundStatement result; + result.plan = make_uniq(GenerateTableIndex()); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp index 7176fb682..139f94670 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -1,72 +1,87 @@ #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" #include "duckdb/planner/expression_binder/insert_binder.hpp" #include "duckdb/common/to_string.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(ExpressionListRef &expr) { - auto result = make_uniq(); - result->types = expr.expected_types; - result->names = expr.expected_names; +BoundStatement Binder::Bind(ExpressionListRef &expr) { + BoundStatement result; + result.types = expr.expected_types; + result.names = expr.expected_names; + + vector>> values; auto prev_can_contain_nulls = this->can_contain_nulls; // bind value list InsertBinder binder(*this, context); binder.target_type = LogicalType(LogicalTypeId::INVALID); for (idx_t list_idx = 0; list_idx < expr.values.size(); list_idx++) { auto &expression_list = expr.values[list_idx]; - if (result->names.empty()) { + if (result.names.empty()) { // no names provided, generate them for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - result->names.push_back("col" + to_string(val_idx)); + result.names.push_back("col" + to_string(val_idx)); } } this->can_contain_nulls = true; vector> list; for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - if (!result->types.empty()) { - 
D_ASSERT(result->types.size() == expression_list.size()); - binder.target_type = result->types[val_idx]; + if (!result.types.empty()) { + D_ASSERT(result.types.size() == expression_list.size()); + binder.target_type = result.types[val_idx]; } auto bound_expr = binder.Bind(expression_list[val_idx]); list.push_back(std::move(bound_expr)); } - result->values.push_back(std::move(list)); + values.push_back(std::move(list)); this->can_contain_nulls = prev_can_contain_nulls; } - if (result->types.empty() && !expr.values.empty()) { + if (result.types.empty() && !expr.values.empty()) { // there are no types specified // we have to figure out the result types // for each column, we iterate over all of the expressions and select the max logical type // we initialize all types to SQLNULL - result->types.resize(expr.values[0].size(), LogicalType::SQLNULL); + result.types.resize(expr.values[0].size(), LogicalType::SQLNULL); // now loop over the lists and select the max logical type - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { - auto ¤t_type = result->types[val_idx]; + auto ¤t_type = result.types[val_idx]; auto next_type = ExpressionBinder::GetExpressionReturnType(*list[val_idx]); - result->types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); + result.types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); } } - for (auto &type : result->types) { + for (auto &type : result.types) { type = LogicalType::NormalizeType(type); } // finally do another loop over the expressions and add casts where required - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { list[val_idx] = - BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result->types[val_idx]); + BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result.types[val_idx]); } } } - result->bind_index = GenerateTableIndex(); - bind_context.AddGenericBinding(result->bind_index, expr.alias, result->names, result->types); - return std::move(result); + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, expr.alias, result.names, result.types); + + // values list, first plan any subqueries in the list + auto root = make_uniq_base(GenerateTableIndex()); + for (auto &expr_list : values) { + for (auto &expr : expr_list) { + PlanSubqueries(expr, root); + } + } + + auto expr_get = make_uniq(bind_index, result.types, std::move(values)); + expr_get->AddChild(std::move(root)); + result.plan = std::move(expr_get); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp index 257e275be..0a6420bfd 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -55,7 +55,7 @@ bool Binder::TryFindBinding(const string &using_column, const string &join_side, } throw BinderException(error); } else { - result = binding.get().alias; + result = binding.get().GetBindingAlias(); } } return true; @@ -122,14 +122,14 @@ static vector 
RemoveDuplicateUsingColumns(const vector &using_co return result; } -unique_ptr Binder::BindJoin(Binder &parent_binder, TableRef &ref) { +BoundStatement Binder::BindJoin(Binder &parent_binder, TableRef &ref) { unnamed_subquery_index = parent_binder.unnamed_subquery_index; auto result = Bind(ref); parent_binder.unnamed_subquery_index = unnamed_subquery_index; return result; } -unique_ptr Binder::Bind(JoinRef &ref) { +BoundStatement Binder::Bind(JoinRef &ref) { auto result = make_uniq(ref.ref_type); result->left_binder = Binder::CreateBinder(context, this); result->right_binder = Binder::CreateBinder(context, this); @@ -188,7 +188,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { case_insensitive_set_t lhs_columns; auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); for (auto &binding : lhs_binding_list) { - for (auto &column_name : binding->names) { + for (auto &column_name : binding->GetColumnNames()) { lhs_columns.insert(column_name); } } @@ -215,7 +215,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { auto &rhs_binding_list = right_binder.bind_context.GetBindingsList(); for (auto &binding_ref : lhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!left_candidates.empty()) { left_candidates += ", "; } @@ -224,7 +224,7 @@ unique_ptr Binder::Bind(JoinRef &ref) { } for (auto &binding_ref : rhs_binding_list) { auto &binding = *binding_ref; - for (auto &column_name : binding.names) { + for (auto &column_name : binding.GetColumnNames()) { if (!right_candidates.empty()) { right_candidates += ", "; } @@ -351,7 +351,13 @@ unique_ptr Binder::Bind(JoinRef &ref) { bind_context.RemoveContext(left_bindings); } - return std::move(result); + BoundStatement result_stmt; + result_stmt.types.insert(result_stmt.types.end(), result->left.types.begin(), result->left.types.end()); + result_stmt.types.insert(result_stmt.types.end(), result->right.types.begin(), result->right.types.end()); + result_stmt.names.insert(result_stmt.names.end(), result->left.names.begin(), result->left.names.end()); + result_stmt.names.insert(result_stmt.names.end(), result->right.names.begin(), result->right.names.end()); + result_stmt.plan = CreatePlan(*result); + return result_stmt; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index 2eb211530..869676a89 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -9,18 +9,18 @@ #include "duckdb/parser/expression/conjunction_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" #include "duckdb/common/types/value_map.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/expression/operator_expression.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/main/client_config.hpp" #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/main/query_result.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include 
"duckdb/planner/operator/logical_pivot.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -58,10 +58,15 @@ static void ConstructPivots(PivotRef &ref, vector &pivot_valu } } -static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns) { +static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns, + optional_ptr macro_binding) { ParsedExpressionIterator::VisitExpression( root_expr, [&](const ColumnRefExpression &child_colref) { if (child_colref.IsQualified()) { + if (child_colref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && macro_binding && + macro_binding->HasMatchingBinding(child_colref.GetName())) { + throw ParameterNotResolvedException(); + } throw BinderException(child_colref, "PIVOT expression cannot contain qualified columns"); } handled_columns.insert(child_colref.GetColumnName()); @@ -378,24 +383,23 @@ static unique_ptr PivotFinalOperator(PivotBindState &bind_state, Piv return final_pivot_operator; } -void ExtractPivotAggregates(BoundTableRef &node, vector> &aggregates) { - if (node.type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected a subquery"); - } - auto &subq = node.Cast(); - if (subq.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected a select node"); - } - auto &select = subq.subquery->Cast(); - if (select.from_table->type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected another subquery"); - } - auto &subq2 = select.from_table->Cast(); - if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected another select node"); +void ExtractPivotAggregates(BoundStatement &node, vector> &aggregates) { + reference op(*node.plan); + bool found_first_aggregate = false; + while (true) { + if (op.get().type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + if (found_first_aggregate) { + break; + } + found_first_aggregate = true; + } + if (op.get().children.size() != 1) { + throw InternalException("Pivot - expected an aggregate"); + } + op = *op.get().children[0]; } - auto &select2 = subq2.subquery->Cast(); - for (auto &aggr : select2.aggregates) { + auto &aggr_op = op.get().Cast(); + for (auto &aggr : aggr_op.expressions) { if (aggr->GetAlias() == "__collated_group") { continue; } @@ -412,15 +416,15 @@ string GetPivotAggregateName(const PivotValueElement &pivot_value, const string return name; } -unique_ptr Binder::BindBoundPivot(PivotRef &ref) { +BoundStatement Binder::BindBoundPivot(PivotRef &ref) { // bind the child table in a child binder - auto result = make_uniq(); - result->bind_index = GenerateTableIndex(); - result->child_binder = Binder::CreateBinder(context, this); - result->child = result->child_binder->Bind(*ref.source); + BoundPivotRef result; + result.bind_index = GenerateTableIndex(); + result.child_binder = Binder::CreateBinder(context, this); + result.child = result.child_binder->Bind(*ref.source); - auto &aggregates = result->bound_pivot.aggregates; - ExtractPivotAggregates(*result->child, aggregates); + auto &aggregates = result.bound_pivot.aggregates; + ExtractPivotAggregates(result.child, aggregates); if (aggregates.size() != ref.bound_aggregate_names.size()) { throw InternalException("Pivot aggregate count mismatch (expected %llu, found %llu)", ref.bound_aggregate_names.size(), aggregates.size()); @@ -428,7 +432,7 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { vector child_names; vector 
child_types; - result->child_binder->bind_context.GetTypesAndNames(child_names, child_types); + result.child_binder->bind_context.GetTypesAndNames(child_names, child_types); vector names; vector types; @@ -453,19 +457,23 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { pivot_str += "_" + str; } } - result->bound_pivot.pivot_values.push_back(std::move(pivot_str)); + result.bound_pivot.pivot_values.push_back(std::move(pivot_str)); names.push_back(std::move(name)); types.push_back(aggr->return_type); } } - result->bound_pivot.group_count = ref.bound_group_names.size(); - result->bound_pivot.types = types; + result.bound_pivot.group_count = ref.bound_group_names.size(); + result.bound_pivot.types = types; auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; QueryResult::DeduplicateColumns(names); - bind_context.AddGenericBinding(result->bind_index, subquery_alias, names, types); + bind_context.AddGenericBinding(result.bind_index, subquery_alias, names, types); - MoveCorrelatedExpressions(*result->child_binder); - return std::move(result); + MoveCorrelatedExpressions(*result.child_binder); + + BoundStatement result_statement; + result_statement.plan = + make_uniq(result.bind_index, std::move(result.child.plan), std::move(result.bound_pivot)); + return result_statement; } unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { @@ -492,7 +500,7 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector Binder::BindPivot(PivotRef &ref, vector Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, vector result; ExtractUnpivotColumnName(*unpivot_expr, result); if (result.empty()) { - throw BinderException( *unpivot_expr, "UNPIVOT clause must contain exactly one column - expression \"%s\" does not contain any", @@ -827,7 +834,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, return result_node; } -unique_ptr Binder::Bind(PivotRef &ref) { +BoundStatement Binder::Bind(PivotRef &ref) { if (!ref.source) { throw InternalException("Pivot without a source!?"); } @@ -858,13 +865,10 @@ unique_ptr Binder::Bind(PivotRef &ref) { } // bind the generated select node auto child_binder = Binder::CreateBinder(context, this); - auto bound_select_node = child_binder->BindNode(*select_node); - auto root_index = bound_select_node->GetRootIndex(); - BoundQueryNode *bound_select_ptr = bound_select_node.get(); + auto result = child_binder->BindNode(*select_node); + auto root_index = result.plan->GetRootIndex(); - unique_ptr result; MoveCorrelatedExpressions(*child_binder); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); auto subquery_alias = ref.alias.empty() ? 
"__unnamed_pivot" : ref.alias; SubqueryRef subquery_ref(nullptr, subquery_alias); subquery_ref.column_name_alias = std::move(ref.column_name_alias); @@ -872,16 +876,14 @@ unique_ptr Binder::Bind(PivotRef &ref) { // if a WHERE clause was provided - bind a subquery holding the WHERE clause // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest child_binder = Binder::CreateBinder(context, this); - child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); auto where_query = make_uniq(); where_query->select_list.push_back(make_uniq()); where_query->where_clause = std::move(where_clause); - bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); - bound_select_ptr = bound_select_node.get(); - root_index = bound_select_node->GetRootIndex(); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + result = child_binder->BindSelectNode(*where_query, std::move(result)); + root_index = result.plan->GetRootIndex(); } - bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); return result; } diff --git a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp index b23456cab..d2d91c3af 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp @@ -5,12 +5,10 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/main/client_context.hpp" @@ -89,7 +87,7 @@ BaseTableColumnInfo FindBaseTableColumn(LogicalOperator &op, idx_t column_index) return FindBaseTableColumn(op, bindings[column_index]); } -unique_ptr Binder::BindShowQuery(ShowRef &ref) { +BoundStatement Binder::BindShowQuery(ShowRef &ref) { // bind the child plan of the DESCRIBE statement auto child_binder = Binder::CreateBinder(context, this); auto plan = child_binder->Bind(*ref.query); @@ -142,12 +140,17 @@ unique_ptr Binder::BindShowQuery(ShowRef &ref) { } collection->Append(append_state, output); - auto show = make_uniq(GenerateTableIndex(), return_types, std::move(collection)); - bind_context.AddGenericBinding(show->table_index, "__show_select", return_names, return_types); - return make_uniq(std::move(show)); + auto table_index = GenerateTableIndex(); + + BoundStatement result; + result.names = return_names; + result.types = return_types; + result.plan = make_uniq(table_index, return_types, std::move(collection)); + bind_context.AddGenericBinding(table_index, "__show_select", return_names, return_types); + return result; } -unique_ptr Binder::BindShowTable(ShowRef &ref) { +BoundStatement Binder::BindShowTable(ShowRef &ref) { auto lname = StringUtil::Lower(ref.table_name); string sql; @@ -193,7 +196,7 @@ unique_ptr Binder::BindShowTable(ShowRef &ref) { return Bind(*subquery); } -unique_ptr 
Binder::Bind(ShowRef &ref) { +BoundStatement Binder::Bind(ShowRef &ref) { if (ref.show_type == ShowType::SUMMARY) { return BindSummarize(ref); } diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp index 9eed0ea61..cfa727927 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -1,15 +1,14 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { -unique_ptr Binder::Bind(SubqueryRef &ref) { +BoundStatement Binder::Bind(SubqueryRef &ref) { auto binder = Binder::CreateBinder(context, this); binder->can_contain_nulls = true; auto subquery = binder->BindNode(*ref.subquery->node); binder->alias = ref.alias.empty() ? "unnamed_subquery" : ref.alias; - idx_t bind_index = subquery->GetRootIndex(); + idx_t bind_index = subquery.plan->GetRootIndex(); string subquery_alias; if (ref.alias.empty()) { auto index = unnamed_subquery_index++; @@ -21,10 +20,14 @@ unique_ptr Binder::Bind(SubqueryRef &ref) { } else { subquery_alias = ref.alias; } - auto result = make_uniq(std::move(binder), std::move(subquery)); - bind_context.AddSubquery(bind_index, subquery_alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + binder->is_outside_flattened = is_outside_flattened; + if (binder->has_unplanned_dependent_joins) { + has_unplanned_dependent_joins = true; + } + bind_context.AddSubquery(bind_index, subquery_alias, ref, subquery); + MoveCorrelatedExpressions(*binder); + + return subquery; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index 0c6e1e0aa..528478c58 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -13,9 +13,6 @@ #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -79,32 +76,28 @@ static TableFunctionBindType GetTableFunctionBindType(TableFunctionCatalogEntry : TableFunctionBindType::STANDARD_TABLE_FUNCTION; } -void Binder::BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery) { +void Binder::BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery) { auto binder = Binder::CreateBinder(this->context, this); - unique_ptr subquery_node; // generate a subquery and bind that (i.e. 
UNNEST([1,2,3]) becomes UNNEST((SELECT [1,2,3])) auto select_node = make_uniq(); select_node->select_list = std::move(expressions); select_node->from_table = make_uniq(); - subquery_node = std::move(select_node); binder->can_contain_nulls = true; - auto node = binder->BindNode(*subquery_node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); + subquery = binder->BindNode(*select_node); + MoveCorrelatedExpressions(*binder); } bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, vector> &expressions, vector &arguments, vector ¶meters, - named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error) { + named_parameter_map_t &named_parameters, BoundStatement &subquery, + ErrorData &error) { auto bind_type = GetTableFunctionBindType(table_function, expressions); if (bind_type == TableFunctionBindType::TABLE_IN_OUT_FUNCTION) { // bind table in-out function BindTableInTableOutFunction(expressions, subquery); // fetch the arguments from the subquery - arguments = subquery->subquery->types; + arguments = subquery.types; return true; } bool seen_subquery = false; @@ -142,12 +135,11 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi auto binder = Binder::CreateBinder(this->context, this); binder->can_contain_nulls = true; auto &se = child->Cast(); - auto node = binder->BindNode(*se.subquery->node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); + subquery = binder->BindNode(*se.subquery->node); + MoveCorrelatedExpressions(*binder); seen_subquery = true; arguments.emplace_back(LogicalTypeId::TABLE); - parameters.emplace_back(Value()); + parameters.emplace_back(); continue; } @@ -188,11 +180,10 @@ static string GetAlias(const TableFunctionRef &ref) { return string(); } -unique_ptr Binder::BindTableFunctionInternal(TableFunction &table_function, - const TableFunctionRef &ref, vector parameters, - named_parameter_map_t named_parameters, - vector input_table_types, - vector input_table_names) { +BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, + vector input_table_names) { auto function_name = GetAlias(ref); auto &column_name_alias = ref.column_name_alias; auto bind_index = GenerateTableIndex(); @@ -221,8 +212,12 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab table_function.name); } } + BoundStatement result; bind_context.AddGenericBinding(bind_index, function_name, return_names, new_plan->types); - return new_plan; + result.names = return_names; + result.types = new_plan->types; + result.plan = std::move(new_plan); + return result; } } if (table_function.bind_replace) { @@ -234,7 +229,7 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab if (!ref.column_name_alias.empty()) { new_plan->column_name_alias = ref.column_name_alias; } - return CreatePlan(*Bind(*new_plan)); + return Bind(*new_plan); } } if (!table_function.bind) { @@ -307,52 +302,46 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab } if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY && correlated_columns.empty()) { + bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), + get->GetTable().get(), std::move(virtual_columns)); + auto window_index = GenerateTableIndex(); auto window = 
make_uniq(window_index); auto row_number = make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + string ordinality_alias = ordinality_column_name; if (return_names.size() < column_name_alias.size()) { row_number->alias = column_name_alias[return_names.size()]; + ordinality_alias = column_name_alias[return_names.size()]; } else { row_number->alias = ordinality_column_name; } + return_names.push_back(ordinality_alias); + return_types.push_back(LogicalType::BIGINT); window->expressions.push_back(std::move(row_number)); - for (idx_t i = 0; i < return_types.size(); i++) { - get->AddColumnId(i); - } + window->types.push_back(LogicalType::BIGINT); window->children.push_back(std::move(get)); + bind_context.AddGenericBinding(window_index, function_name, {ordinality_alias}, {LogicalType::BIGINT}); - vector> select_list; - for (idx_t i = 0; i < return_types.size(); i++) { - auto expression = make_uniq(return_types[i], ColumnBinding(bind_index, i)); - select_list.push_back(std::move(expression)); - } - select_list.push_back(make_uniq(LogicalType::BIGINT, ColumnBinding(window_index, 0))); - - auto projection_index = GenerateTableIndex(); - auto projection = make_uniq(projection_index, std::move(select_list)); - - projection->children.push_back(std::move(window)); - if (return_names.size() < column_name_alias.size()) { - return_names.push_back(column_name_alias[return_names.size()]); - } else { - return_names.push_back(ordinality_column_name); - } - - return_types.push_back(LogicalType::BIGINT); - bind_context.AddGenericBinding(projection_index, function_name, return_names, return_types); - return std::move(projection); + BoundStatement result; + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(window); + return result; } - // now add the table function to the bind context so its columns can be bound + BoundStatement result; bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); - return std::move(get); + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(get); + return result; } -unique_ptr Binder::BindTableFunction(TableFunction &function, vector parameters) { +BoundStatement Binder::BindTableFunction(TableFunction &function, vector parameters) { named_parameter_map_t named_parameters; vector input_table_types; vector input_table_names; @@ -364,7 +353,7 @@ unique_ptr Binder::BindTableFunction(TableFunction &function, v std::move(input_table_types), std::move(input_table_names)); } -unique_ptr Binder::Bind(TableFunctionRef &ref) { +BoundStatement Binder::Bind(TableFunctionRef &ref) { QueryErrorContext error_context(ref.query_location); D_ASSERT(ref.function->GetExpressionType() == ExpressionType::FUNCTION); @@ -388,7 +377,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { binder->can_contain_nulls = true; binder->alias = ref.alias.empty() ? "unnamed_query" : ref.alias; - unique_ptr query; + BoundStatement query; try { query = binder->BindNode(*query_node); } catch (std::exception &ex) { @@ -397,15 +386,14 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { error.Throw(); } - idx_t bind_index = query->GetRootIndex(); + idx_t bind_index = query.plan->GetRootIndex(); // string alias; string alias = (ref.alias.empty() ? 
"unnamed_query" + to_string(bind_index) : ref.alias); - auto result = make_uniq(std::move(binder), std::move(query)); // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, *result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + bind_context.AddSubquery(bind_index, alias, ref, query); + MoveCorrelatedExpressions(*binder); + return query; } D_ASSERT(func_catalog.type == CatalogType::TABLE_FUNCTION_ENTRY); auto &function = func_catalog.Cast(); @@ -414,7 +402,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector arguments; vector parameters; named_parameter_map_t named_parameters; - unique_ptr subquery; + BoundStatement subquery; ErrorData error; if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, error)) { @@ -437,9 +425,9 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector input_table_types; vector input_table_names; - if (subquery) { - input_table_types = subquery->subquery->types; - input_table_names = subquery->subquery->names; + if (subquery.plan) { + input_table_types = subquery.types; + input_table_names = subquery.names; } else if (table_function.in_out_function) { for (auto ¶m : parameters) { input_table_types.push_back(param.type()); @@ -457,7 +445,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { parameters[i] = parameters[i].CastAs(context, target_type); } } - } else if (subquery) { + } else if (subquery.plan) { for (idx_t i = 0; i < arguments.size(); i++) { auto target_type = i < table_function.arguments.size() ? table_function.arguments[i] : table_function.varargs; @@ -469,11 +457,39 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { } } - auto get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), - std::move(input_table_types), std::move(input_table_names)); - auto table_function_ref = make_uniq(std::move(get)); - table_function_ref->subquery = std::move(subquery); - return std::move(table_function_ref); + BoundStatement get; + try { + get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), + std::move(input_table_types), std::move(input_table_names)); + } catch (std::exception &ex) { + error = ErrorData(ex); + error.AddQueryLocation(ref); + error.Throw(); + } + + if (subquery.plan) { + auto child_node = std::move(subquery.plan); + + reference node = *get.plan; + + while (!node.get().children.empty()) { + D_ASSERT(node.get().children.size() == 1); + if (node.get().children.size() != 1) { + throw InternalException( + "Binder::CreatePlan: linear path expected, but found node with %d children", + node.get().children.size()); + } + node = *node.get().children[0]; + } + + D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); + node.get().children.push_back(std::move(child_node)); + } + BoundStatement result_statement; + result_statement.names = get.names; + result_statement.types = get.types; + result_statement.plan = std::move(get.plan); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp deleted file mode 100644 index 085498fbb..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include 
"duckdb/planner/tableref/bound_basetableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundBaseTableRef &ref) { - return std::move(ref.get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp deleted file mode 100644 index 83e965b5e..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include "duckdb/planner/operator/logical_column_data_get.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundColumnDataRef &ref) { - auto types = ref.collection->Types(); - // Create a (potentially owning) LogicalColumnDataGet - auto root = make_uniq_base(ref.bind_index, std::move(types), - std::move(ref.collection)); - return root; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp b/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp deleted file mode 100644 index 4ee2b9a76..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_cteref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTERef &ref) { - return make_uniq(ref.bind_index, ref.cte_index, ref.types, ref.bound_columns, ref.is_recurring); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp deleted file mode 100644 index b674b43df..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/operator/logical_delim_get.hpp" -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundDelimGetRef &ref) { - return make_uniq(ref.bind_index, ref.column_types); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp deleted file mode 100644 index f31fc929b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundEmptyTableRef &ref) { - return make_uniq(ref.bind_index); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp deleted file mode 100644 index ba6253bce..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" -#include "duckdb/planner/operator/logical_expression_get.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundExpressionListRef &ref) { - auto root = make_uniq_base(GenerateTableIndex()); - // values list, first plan any subqueries in the list - for (auto 
&expr_list : ref.values) { - for (auto &expr : expr_list) { - PlanSubqueries(expr, root); - } - } - // now create a LogicalExpressionGet from the set of expressions - // fetch the types - vector types; - for (auto &expr : ref.values[0]) { - types.push_back(expr->return_type); - } - auto expr_get = make_uniq(ref.bind_index, types, std::move(ref.values)); - expr_get->AddChild(std::move(root)); - return std::move(expr_get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp index 9de5829f2..7891c501b 100644 --- a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp @@ -19,6 +19,90 @@ namespace duckdb { +//! Check if a filter can be safely pushed to the left child +//! This is used ONLY for join conditions in the ON clause, not for WHERE clause filters. +//! The logic determines whether a condition that references only the left side can be +//! pushed down as a filter on the left child operator. +static bool CanPushToLeftChild(JoinType type) { + switch (type) { + case JoinType::INNER: + case JoinType::SEMI: + case JoinType::RIGHT: + return true; + case JoinType::ANTI: + case JoinType::LEFT: + case JoinType::OUTER: + return false; + default: + return false; + } +} + +//! Check if a filter can be safely pushed to the right child +//! This is used ONLY for join conditions in the ON clause, not for WHERE clause filters. +//! The logic determines whether a condition that references only the right side can be +//! pushed down as a filter on the right child operator. +static bool CanPushToRightChild(JoinType type) { + switch (type) { + case JoinType::INNER: + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::LEFT: + return true; + case JoinType::RIGHT: + case JoinType::OUTER: + return false; + default: + return false; + } +} + +//! Push a filter expression to a child operator +static void PushFilterToChild(unique_ptr &child, unique_ptr &expr) { + if (child->type != LogicalOperatorType::LOGICAL_FILTER) { + auto filter = make_uniq(); + filter->AddChild(std::move(child)); + child = std::move(filter); + } + + auto &filter = child->Cast(); + filter.expressions.push_back(std::move(expr)); +} + +//! Check if a foldable expression evaluates to TRUE and can be eliminated +static bool CanEliminate(ClientContext &context, JoinType type, unique_ptr &expr) { + if (!expr->IsFoldable()) { + return false; + } + + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(context, *expr, result)) { + return false; + } + + if (result.IsNull()) { + return false; + } + + bool is_true = (result == Value(true)); + + if (is_true) { + switch (type) { + case JoinType::INNER: + case JoinType::LEFT: + case JoinType::RIGHT: + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::OUTER: + return true; + default: + return false; + } + } + + return false; +} + //! Only use conditions that are valid for the join ref type static bool IsJoinTypeCondition(const JoinRefType ref_type, const ExpressionType expr_type) { switch (ref_type) { @@ -39,6 +123,23 @@ static bool IsJoinTypeCondition(const JoinRefType ref_type, const ExpressionType } } +//! 
Check an expression is a usable comparison expression +static bool IsComparisonExpression(const Expression &expr) { + switch (expr.GetExpressionType()) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + case ExpressionType::COMPARE_DISTINCT_FROM: + return true; + default: + return false; + } +} + //! Create a JoinCondition from a comparison static bool CreateJoinCondition(Expression &expr, const unordered_set &left_bindings, const unordered_set &right_bindings, vector &conditions) { @@ -65,57 +166,36 @@ static bool CreateJoinCondition(Expression &expr, const unordered_set &le return false; } +//! Extract join conditions, pushing single-side filters to children when it's safe void LogicalComparisonJoin::ExtractJoinConditions( ClientContext &context, JoinType type, JoinRefType ref_type, unique_ptr &left_child, unique_ptr &right_child, const unordered_set &left_bindings, const unordered_set &right_bindings, vector> &expressions, vector &conditions, vector> &arbitrary_expressions) { - for (auto &expr : expressions) { - auto total_side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings); - if (total_side != JoinSide::BOTH) { - // join condition does not reference both sides, add it as filter under the join - if ((type == JoinType::LEFT || ref_type == JoinRefType::ASOF) && total_side == JoinSide::RIGHT) { - // filter is on RHS and the join is a LEFT OUTER join, we can push it in the right child - if (right_child->type != LogicalOperatorType::LOGICAL_FILTER) { - // not a filter yet, push a new empty filter - auto filter = make_uniq(); - filter->AddChild(std::move(right_child)); - right_child = std::move(filter); - } - // push the expression into the filter - auto &filter = right_child->Cast(); - filter.expressions.push_back(std::move(expr)); + auto side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings); + + if (side == JoinSide::NONE) { + if (CanEliminate(context, type, expr)) { continue; } - // if the join is a LEFT JOIN and the join expression constantly evaluates to TRUE, - // then we do not add it to the arbitrary expressions - if (type == JoinType::LEFT && expr->IsFoldable()) { - Value result; - ExpressionExecutor::TryEvaluateScalar(context, *expr, result); - if (!result.IsNull() && result == Value(true)) { - continue; - } + } else if (side == JoinSide::LEFT) { + if (CanPushToLeftChild(type)) { + PushFilterToChild(left_child, expr); + continue; + } + } else if (side == JoinSide::RIGHT) { + if (CanPushToRightChild(type)) { + PushFilterToChild(right_child, expr); + continue; } - } else if (expr->GetExpressionType() == ExpressionType::COMPARE_EQUAL || - expr->GetExpressionType() == ExpressionType::COMPARE_NOTEQUAL || - expr->GetExpressionType() == ExpressionType::COMPARE_BOUNDARY_START || - expr->GetExpressionType() == ExpressionType::COMPARE_LESSTHAN || - expr->GetExpressionType() == ExpressionType::COMPARE_GREATERTHAN || - expr->GetExpressionType() == ExpressionType::COMPARE_LESSTHANOREQUALTO || - expr->GetExpressionType() == ExpressionType::COMPARE_GREATERTHANOREQUALTO || - expr->GetExpressionType() == ExpressionType::COMPARE_BOUNDARY_START || - expr->GetExpressionType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM || - expr->GetExpressionType() == ExpressionType::COMPARE_DISTINCT_FROM) - - { - // 
comparison, check if we can create a comparison JoinCondition - if (IsJoinTypeCondition(ref_type, expr->GetExpressionType()) && + } else if (side == JoinSide::BOTH) { + if (IsComparisonExpression(*expr) && IsJoinTypeCondition(ref_type, expr->GetExpressionType()) && CreateJoinCondition(*expr, left_bindings, right_bindings, conditions)) { - // successfully created the join condition continue; } } + arbitrary_expressions.push_back(std::move(expr)); } } @@ -146,18 +226,17 @@ void LogicalComparisonJoin::ExtractJoinConditions(ClientContext &context, JoinTy arbitrary_expressions); } -unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &context, JoinType type, - JoinRefType reftype, +//! Create the join operator based on conditions and join type +unique_ptr LogicalComparisonJoin::CreateJoin(JoinType type, JoinRefType ref_type, unique_ptr left_child, unique_ptr right_child, vector conditions, vector> arbitrary_expressions) { - // Validate the conditions - bool need_to_consider_arbitrary_expressions = true; - const bool is_asof = reftype == JoinRefType::ASOF; + const bool is_asof = ref_type == JoinRefType::ASOF; + + // validate ASOF join conditions if (is_asof) { - // Handle case of zero conditions - auto asof_idx = conditions.size() + 1; + idx_t asof_idx = conditions.size(); for (size_t c = 0; c < conditions.size(); ++c) { auto &cond = conditions[c]; switch (cond.comparison) { @@ -182,89 +261,89 @@ unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &con } } - if (type == JoinType::INNER && reftype == JoinRefType::REGULAR) { - // for inner joins we can push arbitrary expressions as a filter - // here we prefer to create a comparison join if possible - // that way we can use the much faster hash join to process the main join - // rather than doing a nested loop join to handle arbitrary expressions + // what type of join to create now? 
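For reference, the pushdown-safety checks introduced above (CanPushToLeftChild/CanPushToRightChild) reduce to two small tables over the join type: a condition in the ON clause that touches only one input may be evaluated as a filter on that input's child only if rows of that input which fail the condition are guaranteed not to appear in the join result. A rough standalone sketch of those tables follows; the JoinType enum, function names, and main() are illustrative only, not the duckdb classes.

// Standalone sketch of the pushdown-safety rules (toy types, not duckdb::).
#include <cstdio>

enum class JoinType { INNER, LEFT, RIGHT, OUTER /* full outer */, SEMI, ANTI };

// A LEFT-only ON condition may become a filter on the left child only when
// left rows that fail it cannot appear in the output. Under LEFT/OUTER joins
// they appear padded with NULLs, and under ANTI joins they are exactly the
// rows returned, so pushing would be incorrect there.
static bool CanPushToLeft(JoinType type) {
	switch (type) {
	case JoinType::INNER:
	case JoinType::SEMI:
	case JoinType::RIGHT:
		return true;
	default:
		return false;
	}
}

// Mirror image for conditions that reference only the RIGHT input.
static bool CanPushToRight(JoinType type) {
	switch (type) {
	case JoinType::INNER:
	case JoinType::SEMI:
	case JoinType::ANTI:
	case JoinType::LEFT:
		return true;
	default:
		return false;
	}
}

int main() {
	// e.g. "l LEFT JOIN r ON r.flag AND l.x = r.y": the right-only condition
	// can be filtered on r up front, while a left-only condition could not be.
	std::printf("LEFT join, right-only cond pushable: %d\n", CanPushToRight(JoinType::LEFT));
	std::printf("LEFT join, left-only cond pushable:  %d\n", CanPushToLeft(JoinType::LEFT));
	return 0;
}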
+ // Case 1: ASOF join - use comparison join + if (is_asof) { + auto asof_join = make_uniq(type, LogicalOperatorType::LOGICAL_ASOF_JOIN); + asof_join->conditions = std::move(conditions); + asof_join->children.push_back(std::move(left_child)); + asof_join->children.push_back(std::move(right_child)); - // for left and full outer joins we HAVE to process all join conditions - // because pushing a filter will lead to an incorrect result, as non-matching tuples cannot be filtered out - need_to_consider_arbitrary_expressions = false; - } - if ((need_to_consider_arbitrary_expressions && !arbitrary_expressions.empty()) || conditions.empty()) { - if (is_asof) { - D_ASSERT(!conditions.empty()); - // We still need to produce an ASOF join here, but it will have to evaluate the arbitrary conditions itself - auto asof_join = make_uniq(type, LogicalOperatorType::LOGICAL_ASOF_JOIN); - asof_join->conditions = std::move(conditions); - asof_join->children.push_back(std::move(left_child)); - asof_join->children.push_back(std::move(right_child)); - // AND all the arbitrary expressions together + if (!arbitrary_expressions.empty()) { asof_join->predicate = std::move(arbitrary_expressions[0]); for (idx_t i = 1; i < arbitrary_expressions.size(); i++) { asof_join->predicate = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(asof_join->predicate), std::move(arbitrary_expressions[i])); } - return std::move(asof_join); } + return std::move(asof_join); + } + + // Case 2: No join conditions - use any join + if (conditions.empty()) { if (arbitrary_expressions.empty()) { - // all conditions were pushed down, add TRUE predicate arbitrary_expressions.push_back(make_uniq(Value::BOOLEAN(true))); } - // if we get here we could not create any JoinConditions - // turn this into an arbitrary expression join + auto any_join = make_uniq(type); - // create the condition any_join->children.push_back(std::move(left_child)); any_join->children.push_back(std::move(right_child)); - // AND all the arbitrary expressions together - // do the same with any remaining conditions - idx_t start_idx = 0; - if (conditions.empty()) { - // no conditions, just use the arbitrary expressions - any_join->condition = std::move(arbitrary_expressions[0]); - start_idx = 1; - } else { - // we have some conditions as well - any_join->condition = JoinCondition::CreateExpression(std::move(conditions[0])); - for (idx_t i = 1; i < conditions.size(); i++) { - any_join->condition = make_uniq( - ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), - JoinCondition::CreateExpression(std::move(conditions[i]))); - } - } - for (idx_t i = start_idx; i < arbitrary_expressions.size(); i++) { + + any_join->condition = std::move(arbitrary_expressions[0]); + for (idx_t i = 1; i < arbitrary_expressions.size(); i++) { any_join->condition = make_uniq( ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), std::move(arbitrary_expressions[i])); } + return std::move(any_join); - } else { - // we successfully converted expressions into JoinConditions - // create a LogicalComparisonJoin - auto logical_type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; - if (is_asof) { - logical_type = LogicalOperatorType::LOGICAL_ASOF_JOIN; - } - auto comp_join = make_uniq(type, logical_type); - comp_join->conditions = std::move(conditions); - comp_join->children.push_back(std::move(left_child)); - comp_join->children.push_back(std::move(right_child)); - if (!arbitrary_expressions.empty()) { - // we have some arbitrary expressions as well - // add them to a filter + } + + 
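Both the ASOF predicate and the any-join condition above are produced by AND-folding whatever expressions remain into a single left-deep conjunction (with a constant TRUE injected when nothing is left). A minimal standalone sketch of that folding, where std::string stands in for the expression tree and FoldConjunction is an illustrative name rather than a duckdb helper:

// Left-deep AND-folding sketch (strings stand in for Expression nodes).
#include <cstdio>
#include <string>
#include <vector>

static std::string FoldConjunction(std::vector<std::string> exprs) {
	if (exprs.empty()) {
		return "true"; // mirrors the TRUE constant used when nothing remains
	}
	std::string result = std::move(exprs[0]);
	for (size_t i = 1; i < exprs.size(); i++) {
		result = "(" + result + " AND " + exprs[i] + ")";
	}
	return result;
}

int main() {
	std::vector<std::string> leftovers {"a.x < b.y", "a.z IS NOT NULL", "b.k % 2 = 0"};
	// prints "((a.x < b.y AND a.z IS NOT NULL) AND b.k % 2 = 0)"
	std::printf("%s\n", FoldConjunction(std::move(leftovers)).c_str());
	return 0;
}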
// Case 3: Has join conditions and arbitrary expressions - decide based on join type + if (!arbitrary_expressions.empty()) { + // for inner join create comparison join + filter on top + if (type == JoinType::INNER) { + auto comp_join = make_uniq(type, LogicalOperatorType::LOGICAL_COMPARISON_JOIN); + comp_join->conditions = std::move(conditions); + comp_join->children.push_back(std::move(left_child)); + comp_join->children.push_back(std::move(right_child)); + auto filter = make_uniq(); for (auto &expr : arbitrary_expressions) { filter->expressions.push_back(std::move(expr)); } - LogicalFilter::SplitPredicates(filter->expressions); filter->children.push_back(std::move(comp_join)); + return std::move(filter); + } else { + auto any_join = make_uniq(type); + any_join->children.push_back(std::move(left_child)); + any_join->children.push_back(std::move(right_child)); + + any_join->condition = JoinCondition::CreateExpression(std::move(conditions[0])); + for (idx_t i = 1; i < conditions.size(); i++) { + any_join->condition = make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), + JoinCondition::CreateExpression(std::move(conditions[i]))); + } + + for (auto &expr : arbitrary_expressions) { + any_join->condition = make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(any_join->condition), std::move(expr)); + } + + return std::move(any_join); } - return std::move(comp_join); } + + // Case 4: Has join conditions but not arbitrary expressions - use comparison join + auto comp_join = make_uniq(type, LogicalOperatorType::LOGICAL_COMPARISON_JOIN); + comp_join->conditions = std::move(conditions); + comp_join->children.push_back(std::move(left_child)); + comp_join->children.push_back(std::move(right_child)); + + return std::move(comp_join); } static bool HasCorrelatedColumns(const Expression &root_expr) { @@ -287,7 +366,7 @@ unique_ptr LogicalComparisonJoin::CreateJoin(ClientContext &con vector> arbitrary_expressions; LogicalComparisonJoin::ExtractJoinConditions(context, type, reftype, left_child, right_child, std::move(condition), conditions, arbitrary_expressions); - return LogicalComparisonJoin::CreateJoin(context, type, reftype, std::move(left_child), std::move(right_child), + return LogicalComparisonJoin::CreateJoin(type, reftype, std::move(left_child), std::move(right_child), std::move(conditions), std::move(arbitrary_expressions)); } @@ -298,8 +377,8 @@ unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { // Set the flag to ensure that children do not flatten before the root is_outside_flattened = false; } - auto left = CreatePlan(*ref.left); - auto right = CreatePlan(*ref.right); + auto left = std::move(ref.left.plan); + auto right = std::move(ref.right.plan); is_outside_flattened = old_is_outside_flattened; // For joins, depth of the bindings will be one higher on the right because of the lateral binder diff --git a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp b/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp deleted file mode 100644 index 4d9482e5b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/planner/operator/logical_pivot.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundPivotRef &ref) { - auto subquery = ref.child_binder->CreatePlan(*ref.child); - - auto result = make_uniq(ref.bind_index, std::move(subquery), std::move(ref.bound_pivot)); - return std::move(result); -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp deleted file mode 100644 index 821654460..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { - // generate the logical plan for the subquery - // this happens separately from the current LogicalPlan generation - ref.binder->is_outside_flattened = is_outside_flattened; - auto subquery = ref.binder->CreatePlan(*ref.subquery); - if (ref.binder->has_unplanned_dependent_joins) { - has_unplanned_dependent_joins = true; - } - return subquery; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp deleted file mode 100644 index 6c2f9957a..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundTableFunction &ref) { - if (ref.subquery) { - auto child_node = CreatePlan(*ref.subquery); - - reference node = *ref.get; - - while (!node.get().children.empty()) { - D_ASSERT(node.get().children.size() == 1); - if (node.get().children.size() != 1) { - throw InternalException( - "Binder::CreatePlan: linear path expected, but found node with %d children", - node.get().children.size()); - } - node = *node.get().children[0]; - } - - D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); - node.get().children.push_back(std::move(child_node)); - } - return std::move(ref.get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binding_alias.cpp b/src/duckdb/src/planner/binding_alias.cpp index 62f60dfa1..b80d1d393 100644 --- a/src/duckdb/src/planner/binding_alias.cpp +++ b/src/duckdb/src/planner/binding_alias.cpp @@ -1,6 +1,7 @@ #include "duckdb/planner/binding_alias.hpp" #include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/keyword_helper.hpp" namespace duckdb { diff --git a/src/duckdb/src/planner/bound_parameter_map.cpp b/src/duckdb/src/planner/bound_parameter_map.cpp index 112a17934..4a906d188 100644 --- a/src/duckdb/src/planner/bound_parameter_map.cpp +++ b/src/duckdb/src/planner/bound_parameter_map.cpp @@ -43,7 +43,6 @@ shared_ptr BoundParameterMap::CreateOrGetData(const string & } unique_ptr BoundParameterMap::BindParameterExpression(ParameterExpression &expr) { - auto &identifier = expr.identifier; D_ASSERT(!parameter_data.count(identifier)); diff --git a/src/duckdb/src/planner/bound_result_modifier.cpp b/src/duckdb/src/planner/bound_result_modifier.cpp index edf49c4b1..4b7710bce 100644 --- a/src/duckdb/src/planner/bound_result_modifier.cpp +++ b/src/duckdb/src/planner/bound_result_modifier.cpp @@ -101,14 +101,17 @@ bool BoundOrderModifier::Equals(const unique_ptr &left, return BoundOrderModifier::Equals(*left, *right); } -bool BoundOrderModifier::Simplify(vector &orders, const vector> &groups) { +bool BoundOrderModifier::Simplify(vector &orders, const vector> &groups, + optional_ptr> grouping_sets) { // for each ORDER BY - check if it is actually necessary // expressions that are in the groups do not need to be ORDERED BY // `ORDER BY` on a 
group has no effect, because for each aggregate, the group is unique // similarly, we only need to ORDER BY each aggregate once + expression_map_t group_expressions; expression_set_t seen_expressions; + idx_t i = 0; for (auto &target : groups) { - seen_expressions.insert(*target); + group_expressions.insert({*target, i++}); } vector new_order_nodes; for (auto &order_node : orders) { @@ -116,16 +119,30 @@ bool BoundOrderModifier::Simplify(vector &orders, const vector // we do not need to order by this node continue; } + auto it = group_expressions.find(*order_node.expression); + bool add_to_new_order = it == group_expressions.end(); + if (!add_to_new_order && grouping_sets) { + idx_t group_idx = it->second; + for (auto &grouping_set : *grouping_sets) { + if (grouping_set.find(group_idx) == grouping_set.end()) { + add_to_new_order = true; + break; + } + } + } seen_expressions.insert(*order_node.expression); - new_order_nodes.push_back(std::move(order_node)); + if (add_to_new_order) { + new_order_nodes.push_back(std::move(order_node)); + } } orders.swap(new_order_nodes); return orders.empty(); // NOLINT } -bool BoundOrderModifier::Simplify(const vector> &groups) { - return Simplify(orders, groups); +bool BoundOrderModifier::Simplify(const vector> &groups, + optional_ptr> grouping_sets) { + return Simplify(orders, groups, grouping_sets); } BoundLimitNode::BoundLimitNode(LimitNodeType type, idx_t constant_integer, double constant_percentage, diff --git a/src/duckdb/src/planner/collation_binding.cpp b/src/duckdb/src/planner/collation_binding.cpp index 1ddefb9a8..dd371bbc4 100644 --- a/src/duckdb/src/planner/collation_binding.cpp +++ b/src/duckdb/src/planner/collation_binding.cpp @@ -8,6 +8,7 @@ #include "duckdb/function/function_binder.hpp" namespace duckdb { +constexpr const char *CollateCatalogEntry::Name; bool PushVarcharCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, CollationType type) { @@ -109,11 +110,34 @@ bool PushIntervalCollation(ClientContext &context, unique_ptr &sourc return true; } +bool PushVariantCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, + CollationType) { + if (sql_type.id() != LogicalTypeId::VARIANT) { + return false; + } + auto &catalog = Catalog::GetSystemCatalog(context); + auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "variant_normalize"); + if (function_entry.functions.Size() != 1) { + throw InternalException("variant_normalize should only have a single overload"); + } + auto source_alias = source->GetAlias(); + auto &scalar_function = function_entry.functions.GetFunctionReferenceByOffset(0); + vector> children; + children.push_back(std::move(source)); + + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); + function->SetAlias(source_alias); + source = std::move(function); + return true; +} + // timetz_byte_comparable CollationBinding::CollationBinding() { RegisterCollation(CollationCallback(PushVarcharCollation)); RegisterCollation(CollationCallback(PushTimeTZCollation)); RegisterCollation(CollationCallback(PushIntervalCollation)); + RegisterCollation(CollationCallback(PushVariantCollation)); } void CollationBinding::RegisterCollation(CollationCallback callback) { diff --git a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp index 68bc16b26..f0f30e030 100644 --- 
a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp @@ -11,7 +11,7 @@ namespace duckdb { BoundAggregateExpression::BoundAggregateExpression(AggregateFunction function, vector> children, unique_ptr filter, unique_ptr bind_info, AggregateType aggr_type) - : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.return_type), + : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.GetReturnType()), function(std::move(function)), children(std::move(children)), bind_info(std::move(bind_info)), aggr_type(aggr_type), filter(std::move(filter)) { D_ASSERT(!this->function.name.empty()); @@ -61,8 +61,8 @@ bool BoundAggregateExpression::Equals(const BaseExpression &other_p) const { } bool BoundAggregateExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false - : Expression::PropagatesNullValues(); + return function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING ? false + : Expression::PropagatesNullValues(); } unique_ptr BoundAggregateExpression::Copy() const { diff --git a/src/duckdb/src/planner/expression/bound_cast_expression.cpp b/src/duckdb/src/planner/expression/bound_cast_expression.cpp index 9419a8db0..5d7946ec3 100644 --- a/src/duckdb/src/planner/expression/bound_cast_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_cast_expression.cpp @@ -127,6 +127,9 @@ bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) { return false; } + if (source_type.id() == LogicalTypeId::VARIANT || target_type.id() == LogicalTypeId::VARIANT) { + return false; + } if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) { uint8_t source_width, target_width; uint8_t source_scale, target_scale; diff --git a/src/duckdb/src/planner/expression/bound_function_expression.cpp b/src/duckdb/src/planner/expression/bound_function_expression.cpp index 5556dec21..dc6332531 100644 --- a/src/duckdb/src/planner/expression/bound_function_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_function_expression.cpp @@ -19,16 +19,16 @@ BoundFunctionExpression::BoundFunctionExpression(LogicalType return_type, Scalar } bool BoundFunctionExpression::IsVolatile() const { - return function.stability == FunctionStability::VOLATILE ? true : Expression::IsVolatile(); + return function.GetStability() == FunctionStability::VOLATILE ? true : Expression::IsVolatile(); } bool BoundFunctionExpression::IsConsistent() const { - return function.stability != FunctionStability::CONSISTENT ? false : Expression::IsConsistent(); + return function.GetStability() != FunctionStability::CONSISTENT ? false : Expression::IsConsistent(); } bool BoundFunctionExpression::IsFoldable() const { // functions with side effects cannot be folded: they have to be executed once for every row - if (function.bind_lambda) { + if (function.HasBindLambdaCallback()) { // This is a lambda function D_ASSERT(bind_info); auto &lambda_bind_data = bind_info->Cast(); @@ -39,11 +39,11 @@ bool BoundFunctionExpression::IsFoldable() const { } } } - return function.stability == FunctionStability::VOLATILE ? false : Expression::IsFoldable(); + return function.GetStability() == FunctionStability::VOLATILE ? 
false : Expression::IsFoldable(); } bool BoundFunctionExpression::CanThrow() const { - if (function.errors == FunctionErrors::CAN_THROW_RUNTIME_ERROR) { + if (function.GetErrorMode() == FunctionErrors::CAN_THROW_RUNTIME_ERROR) { return true; } return Expression::CanThrow(); @@ -54,8 +54,8 @@ string BoundFunctionExpression::ToString() const { is_operator); } bool BoundFunctionExpression::PropagatesNullValues() const { - return function.null_handling == FunctionNullHandling::SPECIAL_HANDLING ? false - : Expression::PropagatesNullValues(); + return function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING ? false + : Expression::PropagatesNullValues(); } hash_t BoundFunctionExpression::Hash() const { @@ -112,16 +112,16 @@ unique_ptr BoundFunctionExpression::Deserialize(Deserializer &deseri auto entry = FunctionSerializer::Deserialize( deserializer, CatalogType::SCALAR_FUNCTION_ENTRY, children, return_type); - auto function_return_type = entry.first.return_type; + auto function_return_type = entry.first.GetReturnType(); auto is_operator = deserializer.ReadProperty(202, "is_operator"); - if (entry.first.bind_expression) { + if (entry.first.HasBindExpressionCallback()) { // bind the function expression auto &context = deserializer.Get(); auto bind_input = FunctionBindExpressionInput(context, entry.second, children); // replace the function expression with the bound expression - auto bound_expression = entry.first.bind_expression(bind_input); + auto bound_expression = entry.first.GetBindExpressionCallback()(bind_input); if (bound_expression) { return bound_expression; } diff --git a/src/duckdb/src/planner/expression_binder.cpp b/src/duckdb/src/planner/expression_binder.cpp index 5141765bb..f78babc4c 100644 --- a/src/duckdb/src/planner/expression_binder.cpp +++ b/src/duckdb/src/planner/expression_binder.cpp @@ -1,6 +1,5 @@ #include "duckdb/planner/expression_binder.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/parser/expression/list.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/planner/binder.hpp" @@ -8,6 +7,7 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/main/client_config.hpp" +#include "duckdb/common/string_util.hpp" namespace duckdb { @@ -70,7 +70,7 @@ BindResult ExpressionBinder::BindExpression(unique_ptr &expr, case ExpressionClass::COLLATE: return BindExpression(expr_ref.Cast(), depth); case ExpressionClass::COLUMN_REF: - return BindExpression(expr_ref.Cast(), depth, root_expression); + return BindExpression(expr_ref.Cast(), depth, root_expression, expr); case ExpressionClass::LAMBDA_REF: return BindExpression(expr_ref.Cast(), depth); case ExpressionClass::COMPARISON: @@ -103,7 +103,9 @@ BindResult ExpressionBinder::BindExpression(unique_ptr &expr, case ExpressionClass::STAR: return BindResult(BinderException::Unsupported(expr_ref, "STAR expression is not supported here")); default: - throw NotImplementedException("Unimplemented expression class"); + return BindResult( + NotImplementedException("Unimplemented expression class in ExpressionBinder::BindExpression: %s", + EnumUtil::ToString(expr_ref.GetExpressionClass()))); } } @@ -164,7 +166,7 @@ static bool CombineMissingColumns(ErrorData ¤t, ErrorData new_error) { } auto score = StringUtil::SimilarityRating(candidate_column, column_name); candidates.insert(candidate); - scores.emplace_back(make_pair(std::move(candidate), score)); + scores.emplace_back(std::move(candidate), 
score); } // get a new top-n auto top_candidates = StringUtil::TopNStrings(scores); @@ -396,7 +398,14 @@ bool ExpressionBinder::IsUnnestFunction(const string &function_name) { return function_name == "unnest" || function_name == "unlist"; } -bool ExpressionBinder::TryBindAlias(ColumnRefExpression &colref, bool root_expression, BindResult &result) { +bool ExpressionBinder::IsPotentialAlias(const ColumnRefExpression &colref) { + // traditional alias (unqualified), or qualified with table name "alias" + if (!colref.IsQualified()) { + return true; + } + if (colref.column_names.size() == 2) { + return StringUtil::CIEquals(colref.GetTableName(), "alias"); + } return false; } diff --git a/src/duckdb/src/planner/expression_binder/check_binder.cpp b/src/duckdb/src/planner/expression_binder/check_binder.cpp index c89c96ded..c6f1abb5a 100644 --- a/src/duckdb/src/planner/expression_binder/check_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/check_binder.cpp @@ -43,7 +43,6 @@ BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref } BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { - if (!colref.IsQualified()) { if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { diff --git a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp index c4477f9e3..4d6ca6727 100644 --- a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp @@ -13,17 +13,15 @@ ColumnAliasBinder::ColumnAliasBinder(SelectBindState &bind_state) : bind_state(b bool ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, unique_ptr &expr_ptr, idx_t depth, bool root_expression, BindResult &result) { - D_ASSERT(expr_ptr->GetExpressionClass() == ExpressionClass::COLUMN_REF); auto &expr = expr_ptr->Cast(); - // Qualified columns cannot be aliases. - if (expr.IsQualified()) { + if (!ExpressionBinder::IsPotentialAlias(expr)) { return false; } // We try to find the alias in the alias_map and return false, if no alias exists. 
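The new IsPotentialAlias check above accepts either an unqualified column reference or a two-part name whose qualifier is the literal word "alias" (compared case-insensitively). A self-contained sketch of that rule on a plain vector of name parts; the helpers below are toy stand-ins, not the ColumnRefExpression API, and the ORDER BY example is one plausible use assuming the alias "b" exists in the SELECT list.

// Sketch of the "potential alias" rule (toy helpers, not duckdb::).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

static bool CIEquals(const std::string &a, const std::string &b) {
	return a.size() == b.size() &&
	       std::equal(a.begin(), a.end(), b.begin(),
	                  [](unsigned char x, unsigned char y) { return std::tolower(x) == std::tolower(y); });
}

static bool IsPotentialAlias(const std::vector<std::string> &column_names) {
	if (column_names.size() == 1) {
		return true; // traditional unqualified alias reference
	}
	// qualified form: the qualifier must be the literal table name "alias"
	return column_names.size() == 2 && CIEquals(column_names.front(), "alias");
}

int main() {
	// e.g. SELECT a + 1 AS b FROM t ORDER BY alias.b  -> {"alias", "b"} qualifies
	std::printf("%d\n", IsPotentialAlias({"b"}));          // 1
	std::printf("%d\n", IsPotentialAlias({"alias", "b"})); // 1
	std::printf("%d\n", IsPotentialAlias({"t", "b"}));     // 0
	return 0;
}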
- auto alias_entry = bind_state.alias_map.find(expr.column_names[0]); + auto alias_entry = bind_state.alias_map.find(expr.column_names.back()); if (alias_entry == bind_state.alias_map.end()) { return false; } @@ -43,11 +41,11 @@ bool ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, unique_ptr return true; } -bool ColumnAliasBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { - if (!colref.IsQualified()) { - return bind_state.alias_map.find(colref.column_names[0]) != bind_state.alias_map.end(); +bool ColumnAliasBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { + if (!ExpressionBinder::IsPotentialAlias(colref)) { + return false; } - return false; + return bind_state.alias_map.find(colref.column_names[0]) != bind_state.alias_map.end(); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/constant_binder.cpp b/src/duckdb/src/planner/expression_binder/constant_binder.cpp index 97a65ba31..01f4ab11d 100644 --- a/src/duckdb/src/planner/expression_binder/constant_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/constant_binder.cpp @@ -19,7 +19,7 @@ BindResult ConstantBinder::BindExpression(unique_ptr &expr_ptr return BindExpression(expr_ptr, depth, root_expression); } } - return BindUnsupportedExpression(expr, depth, clause + " cannot contain column names"); + throw BinderException::Unsupported(expr, clause + " cannot contain column names"); } case ExpressionClass::SUBQUERY: throw BinderException(clause + " cannot contain subqueries"); diff --git a/src/duckdb/src/planner/expression_binder/group_binder.cpp b/src/duckdb/src/planner/expression_binder/group_binder.cpp index cdec41e15..975841b0d 100644 --- a/src/duckdb/src/planner/expression_binder/group_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/group_binder.cpp @@ -6,6 +6,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" #include "duckdb/common/to_string.hpp" +#include "duckdb/common/string_util.hpp" namespace duckdb { @@ -20,7 +21,7 @@ BindResult GroupBinder::BindExpression(unique_ptr &expr_ptr, i if (root_expression && depth == 0) { switch (expr.GetExpressionClass()) { case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr.Cast()); + return BindColumnRef(expr.Cast(), expr_ptr); case ExpressionClass::CONSTANT: return BindConstant(expr.Cast()); case ExpressionClass::PARAMETER: @@ -79,9 +80,12 @@ BindResult GroupBinder::BindConstant(ConstantExpression &constant) { return BindSelectRef(index - 1); } -bool GroupBinder::TryBindAlias(ColumnRefExpression &colref, bool root_expression, BindResult &result) { +bool GroupBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, + BindResult &result, unique_ptr &expr_ptr) { + // try to resolve alias references in GROUP // failed to bind the column and the node is the root expression with depth = 0 // check if refers to an alias in the select clause + auto &alias_name = colref.GetColumnName(); auto entry = bind_state.alias_map.find(alias_name); if (entry == bind_state.alias_map.end()) { @@ -102,14 +106,14 @@ bool GroupBinder::TryBindAlias(ColumnRefExpression &colref, bool root_expression return true; } -BindResult GroupBinder::BindColumnRef(ColumnRefExpression &colref) { +BindResult GroupBinder::BindColumnRef(ColumnRefExpression &colref, unique_ptr &expr_ptr) { // columns in GROUP BY clauses: // FIRST refer to the original tables, and // THEN if no match is found refer to aliases in the 
SELECT list // THEN if no match is found, refer to outer queries // first try to bind to the base columns (original tables) - return ExpressionBinder::BindExpression(colref, 0, true); + return ExpressionBinder::BindExpression(colref, 0, true, expr_ptr); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/having_binder.cpp b/src/duckdb/src/planner/expression_binder/having_binder.cpp index 902add5e2..ab0f11af5 100644 --- a/src/duckdb/src/planner/expression_binder/having_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/having_binder.cpp @@ -3,7 +3,6 @@ #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/window_expression.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" @@ -32,14 +31,13 @@ unique_ptr HavingBinder::QualifyColumnName(ColumnRefExpression if (group_index != DConstants::INVALID_INDEX) { return qualified_colref; } - if (column_alias_binder.QualifyColumnAlias(colref)) { + if (column_alias_binder.DoesColumnAliasExist(colref)) { return nullptr; } return qualified_colref; } BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - // Keep the original column name to return a meaningful error message. auto col_ref = expr_ptr->Cast(); const auto &column_name = col_ref.GetColumnName(); @@ -91,7 +89,7 @@ BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, i } BindResult HavingBinder::BindWindow(WindowExpression &expr, idx_t depth) { - return BindResult(BinderException::Unsupported(expr, "HAVING clause cannot contain window functions!")); + throw BinderException::Unsupported(expr, "HAVING clause cannot contain window functions!"); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp index 205b644e8..0b693558a 100644 --- a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -3,7 +3,7 @@ #include "duckdb/planner/logical_operator_visitor.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_subquery_expression.hpp" -#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" namespace duckdb { @@ -17,7 +17,7 @@ void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { // add the correlated column info CorrelatedColumnInfo info(bound_colref); if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(std::move(info)); + correlated_columns.AddColumn(std::move(info)); // TODO is adding to the front OK here? 
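The comment block above spells out the GROUP BY resolution order: base table columns first, then SELECT-list aliases, then outer queries. A toy, self-contained illustration of that precedence follows; plain name sets stand in for the actual binders, and ResolveGroupByName is an invented helper, not part of the Binder API.

// Toy illustration of the documented GROUP BY lookup order.
#include <cstdio>
#include <set>
#include <string>

enum class Resolution { BASE_COLUMN, SELECT_ALIAS, OUTER_QUERY, UNRESOLVED };

static Resolution ResolveGroupByName(const std::string &name, const std::set<std::string> &base_columns,
                                     const std::set<std::string> &select_aliases,
                                     const std::set<std::string> &outer_columns) {
	if (base_columns.count(name)) {
		return Resolution::BASE_COLUMN; // original tables take precedence
	}
	if (select_aliases.count(name)) {
		return Resolution::SELECT_ALIAS; // only consulted if no base column matched
	}
	if (outer_columns.count(name)) {
		return Resolution::OUTER_QUERY; // correlated reference as a last resort
	}
	return Resolution::UNRESOLVED;
}

int main() {
	// SELECT a + 1 AS a FROM t GROUP BY a: "a" binds to the base column t.a,
	// not to the alias, because base columns are tried first.
	std::set<std::string> base {"a", "b"}, aliases {"a", "c"}, outer {};
	std::printf("%d\n", static_cast<int>(ResolveGroupByName("a", base, aliases, outer))); // BASE_COLUMN
	std::printf("%d\n", static_cast<int>(ResolveGroupByName("c", base, aliases, outer))); // SELECT_ALIAS
	return 0;
}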
} } } @@ -54,8 +54,7 @@ string LateralBinder::UnsupportedAggregateMessage() { return "LATERAL join cannot contain aggregates!"; } -static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, - const vector &correlated_columns) { +static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, const CorrelatedColumns &correlated_columns) { // don't need to reduce this if (expr.depth == 0) { return; @@ -69,8 +68,7 @@ static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, } } -static void ReduceColumnDepth(vector &columns, - const vector &affected_columns) { +static void ReduceColumnDepth(CorrelatedColumns &columns, const CorrelatedColumns &affected_columns) { for (auto &s_correlated : columns) { for (auto &affected : affected_columns) { if (affected == s_correlated) { @@ -81,45 +79,44 @@ static void ReduceColumnDepth(vector &columns, } } -class ExpressionDepthReducerRecursive : public BoundNodeVisitor { +class ExpressionDepthReducerRecursive : public LogicalOperatorVisitor { public: - explicit ExpressionDepthReducerRecursive(const vector &correlated) - : correlated_columns(correlated) { + explicit ExpressionDepthReducerRecursive(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } - void VisitExpression(unique_ptr &expression) override { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - ReduceColumnRefDepth(expression->Cast(), correlated_columns); - } else if (expression->GetExpressionType() == ExpressionType::SUBQUERY) { - ReduceExpressionSubquery(expression->Cast(), correlated_columns); + void VisitExpression(unique_ptr *expression) override { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + ReduceColumnRefDepth(expr.Cast(), correlated_columns); + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { + ReduceExpressionSubquery(expr.Cast(), correlated_columns); } - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } - void VisitBoundTableRef(BoundTableRef &ref) override { - if (ref.type == TableReferenceType::JOIN) { + void VisitOperator(LogicalOperator &op) override { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - auto &bound_join = ref.Cast(); + auto &bound_join = op.Cast(); ReduceColumnDepth(bound_join.correlated_columns, correlated_columns); } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } - static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, - const vector &correlated_columns) { + static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, const CorrelatedColumns &correlated_columns) { ReduceColumnDepth(expr.binder->correlated_columns, correlated_columns); ExpressionDepthReducerRecursive recursive(correlated_columns); - recursive.VisitBoundQueryNode(*expr.subquery); + recursive.VisitOperator(*expr.subquery.plan); } private: - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; class ExpressionDepthReducer : public LogicalOperatorVisitor { public: - explicit ExpressionDepthReducer(const vector &correlated) : correlated_columns(correlated) { + explicit ExpressionDepthReducer(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } protected: @@ -133,10 +130,10 @@ class ExpressionDepthReducer : public LogicalOperatorVisitor { return nullptr; } - const vector &correlated_columns; + const 
CorrelatedColumns &correlated_columns; }; -void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector &correlated) { +void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &correlated) { ExpressionDepthReducer depth_reducer(correlated); depth_reducer.VisitOperator(op); } diff --git a/src/duckdb/src/planner/expression_binder/order_binder.cpp b/src/duckdb/src/planner/expression_binder/order_binder.cpp index 6b093eb8f..3923129fb 100644 --- a/src/duckdb/src/planner/expression_binder/order_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/order_binder.cpp @@ -76,12 +76,13 @@ optional_idx OrderBinder::TryGetProjectionReference(ParsedExpression &expr) cons } case ExpressionClass::COLUMN_REF: { auto &colref = expr.Cast(); - // if there is an explicit table name we can't bind to an alias - if (colref.IsQualified()) { + if (!ExpressionBinder::IsPotentialAlias(colref)) { break; } + + string alias_name = colref.column_names.back(); // check the alias list - auto entry = bind_state.alias_map.find(colref.column_names[0]); + auto entry = bind_state.alias_map.find(alias_name); if (entry == bind_state.alias_map.end()) { break; } diff --git a/src/duckdb/src/planner/expression_binder/select_binder.cpp b/src/duckdb/src/planner/expression_binder/select_binder.cpp index aab666a47..acbba314b 100644 --- a/src/duckdb/src/planner/expression_binder/select_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/select_binder.cpp @@ -1,5 +1,6 @@ #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" namespace duckdb { @@ -8,54 +9,51 @@ SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNo : BaseSelectBinder(binder, context, node, info) { } -unique_ptr SelectBinder::GetSQLValueFunction(const string &column_name) { - auto alias_entry = node.bind_state.alias_map.find(column_name); - if (alias_entry != node.bind_state.alias_map.end()) { - // don't replace SQL value functions if they are in the alias map - return nullptr; +bool SelectBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, + BindResult &result, unique_ptr &expr_ptr) { + // must be a qualified alias. 
+ if (!ExpressionBinder::IsPotentialAlias(colref)) { + return false; + } + + const auto &alias_name = colref.column_names.back(); + auto entry = node.bind_state.alias_map.find(alias_name); + if (entry == node.bind_state.alias_map.end()) { + return false; } - return ExpressionBinder::GetSQLValueFunction(column_name); -} -BindResult SelectBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - // first try to bind the column reference regularly - auto result = BaseSelectBinder::BindColumnRef(expr_ptr, depth, root_expression); - if (!result.HasError()) { - return result; + auto alias_index = entry->second; + // Simple way to prevent circular aliasing (`SELECT alias.y as x, alias.x as y;`) + if (alias_index >= node.bound_column_count) { + throw BinderException("Column \"%s\" referenced that exists in the SELECT clause - but this column " + "cannot be referenced before it is defined", + colref.column_names.back()); } - // binding failed - // check in the alias map - auto &colref = (expr_ptr.get())->Cast(); - if (!colref.IsQualified()) { - auto &bind_state = node.bind_state; - auto alias_entry = node.bind_state.alias_map.find(colref.column_names[0]); - if (alias_entry != node.bind_state.alias_map.end()) { - // found entry! - auto index = alias_entry->second; - if (index >= node.bound_column_count) { - throw BinderException("Column \"%s\" referenced that exists in the SELECT clause - but this column " - "cannot be referenced before it is defined", - colref.column_names[0]); - } - if (bind_state.AliasHasSubquery(index)) { - throw BinderException("Alias \"%s\" referenced in a SELECT clause - but the expression has a subquery." - " This is not yet supported.", - colref.column_names[0]); - } - auto copied_expression = node.bind_state.BindAlias(index); - result = BindExpression(copied_expression, depth, false); - return result; - } + + if (node.bind_state.AliasHasSubquery(alias_index)) { + throw BinderException(colref, + "Alias \"%s\" referenced in a SELECT clause - but the expression has a subquery. 
This is " + "not yet supported.", + alias_name); } - // entry was not found in the alias map: return the original error - return result; + auto copied_unbound = node.bind_state.BindAlias(alias_index); + result = BindExpression(copied_unbound, depth, false); + return true; +} + +bool SelectBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { + // Using `back()` to support both qualified and unqualified aliasing + auto alias_name = colref.column_names.back(); + return node.bind_state.alias_map.find(alias_name) != node.bind_state.alias_map.end(); } -bool SelectBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { - if (!colref.IsQualified()) { - return node.bind_state.alias_map.find(colref.column_names[0]) != node.bind_state.alias_map.end(); +unique_ptr SelectBinder::GetSQLValueFunction(const string &column_name) { + auto alias_entry = node.bind_state.alias_map.find(column_name); + if (alias_entry != node.bind_state.alias_map.end()) { + // don't replace SQL value functions if they are in the alias map + return nullptr; } - return false; + return ExpressionBinder::GetSQLValueFunction(column_name); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index 198bd072b..16b612cd6 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -27,9 +27,13 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr if (lambda_ref) { return BindLambdaReference(lambda_ref->Cast(), depth); } + if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { throw ParameterNotResolvedException(); } + } else if (col_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && binder.macro_binding && + binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { + throw ParameterNotResolvedException(); } auto query_location = col_ref.GetQueryLocation(); @@ -51,6 +55,15 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr if (value_function) { return BindExpression(value_function, depth, root_expression); } + + auto result = BindCorrelatedColumns(expr_ptr, ErrorData("error")); + if (!result.HasError()) { + auto &bound_expr = expr_ptr->Cast(); + ExtractCorrelatedExpressions(binder, *bound_expr.expr); + result.expression = std::move(bound_expr.expr); + return result; + } + if (table_function_name.empty()) { throw BinderException(query_location, "Failed to bind \"%s\" - COLUMNS expression can only contain lambda parameters", diff --git a/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp b/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp new file mode 100644 index 000000000..5b2e4bcd1 --- /dev/null +++ b/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp @@ -0,0 +1,15 @@ +#include "duckdb/planner/expression_binder/try_operator_binder.hpp" + +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +TryOperatorBinder::TryOperatorBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context, true) { +} + +BindResult TryOperatorBinder::BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, + idx_t depth) { + throw BinderException("aggregates are not allowed inside the TRY expression"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/where_binder.cpp 
b/src/duckdb/src/planner/expression_binder/where_binder.cpp index 9b25c7930..21ff6c2cd 100644 --- a/src/duckdb/src/planner/expression_binder/where_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/where_binder.cpp @@ -9,22 +9,6 @@ WhereBinder::WhereBinder(Binder &binder, ClientContext &context, optional_ptr &expr_ptr, idx_t depth, bool root_expression) { - - auto result = ExpressionBinder::BindExpression(expr_ptr, depth); - if (!result.HasError() || !column_alias_binder) { - return result; - } - - BindResult alias_result; - auto found_alias = column_alias_binder->BindAlias(*this, expr_ptr, depth, root_expression, alias_result); - if (found_alias) { - return alias_result; - } - - return result; -} - BindResult WhereBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { auto &expr = *expr_ptr; switch (expr.GetExpressionClass()) { @@ -32,8 +16,6 @@ BindResult WhereBinder::BindExpression(unique_ptr &expr_ptr, i return BindUnsupportedExpression(expr, depth, "WHERE clause cannot contain DEFAULT clause"); case ExpressionClass::WINDOW: return BindUnsupportedExpression(expr, depth, "WHERE clause cannot contain window functions!"); - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr_ptr, depth, root_expression); default: return ExpressionBinder::BindExpression(expr_ptr, depth); } @@ -43,9 +25,17 @@ string WhereBinder::UnsupportedAggregateMessage() { return "WHERE clause cannot contain aggregates!"; } -bool WhereBinder::QualifyColumnAlias(const ColumnRefExpression &colref) { +bool WhereBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, + BindResult &result, unique_ptr &expr_ptr) { + if (!column_alias_binder) { + return false; + } + return column_alias_binder->BindAlias(*this, expr_ptr, depth, root_expression, result); +} + +bool WhereBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { if (column_alias_binder) { - return column_alias_binder->QualifyColumnAlias(colref); + return column_alias_binder->DoesColumnAliasExist(colref); } return false; } diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp index 042712732..3d1407900 100644 --- a/src/duckdb/src/planner/expression_iterator.cpp +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -4,8 +4,6 @@ #include "duckdb/planner/expression/list.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/tableref/list.hpp" #include "duckdb/common/enum_util.hpp" @@ -183,156 +181,4 @@ void ExpressionIterator::VisitExpressionClassMutable( *expr, [&](unique_ptr &child) { VisitExpressionClassMutable(child, expr_class, callback); }); } -void BoundNodeVisitor::VisitExpression(unique_ptr &expression) { - VisitExpressionChildren(*expression); -} - -void BoundNodeVisitor::VisitExpressionChildren(Expression &expr) { - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { VisitExpression(expr); }); -} - -void BoundNodeVisitor::VisitBoundQueryNode(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SET_OPERATION_NODE: { - auto &bound_setop = node.Cast(); - for (auto &child : bound_setop.bound_children) { - VisitBoundQueryNode(*child.node); - } - break; - } - case QueryNodeType::RECURSIVE_CTE_NODE: { - auto &cte_node = node.Cast(); - 
VisitBoundQueryNode(*cte_node.left); - VisitBoundQueryNode(*cte_node.right); - break; - } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - VisitBoundQueryNode(*cte_node.child); - VisitBoundQueryNode(*cte_node.query); - break; - } - case QueryNodeType::SELECT_NODE: { - auto &bound_select = node.Cast(); - for (auto &expr : bound_select.select_list) { - VisitExpression(expr); - } - if (bound_select.where_clause) { - VisitExpression(bound_select.where_clause); - } - for (auto &expr : bound_select.groups.group_expressions) { - VisitExpression(expr); - } - if (bound_select.having) { - VisitExpression(bound_select.having); - } - for (auto &expr : bound_select.aggregates) { - VisitExpression(expr); - } - for (auto &entry : bound_select.unnests) { - for (auto &expr : entry.second.expressions) { - VisitExpression(expr); - } - } - for (auto &expr : bound_select.windows) { - VisitExpression(expr); - } - if (bound_select.from_table) { - VisitBoundTableRef(*bound_select.from_table); - } - break; - } - default: - throw NotImplementedException("Unimplemented query node in ExpressionIterator"); - } - for (idx_t i = 0; i < node.modifiers.size(); i++) { - switch (node.modifiers[i]->type) { - case ResultModifierType::DISTINCT_MODIFIER: - for (auto &expr : node.modifiers[i]->Cast().target_distincts) { - VisitExpression(expr); - } - break; - case ResultModifierType::ORDER_MODIFIER: - for (auto &order : node.modifiers[i]->Cast().orders) { - VisitExpression(order.expression); - } - break; - case ResultModifierType::LIMIT_MODIFIER: { - auto &limit_expr = node.modifiers[i]->Cast().limit_val.GetExpression(); - auto &offset_expr = node.modifiers[i]->Cast().offset_val.GetExpression(); - if (limit_expr) { - VisitExpression(limit_expr); - } - if (offset_expr) { - VisitExpression(offset_expr); - } - break; - } - default: - break; - } - } -} - -class LogicalBoundNodeVisitor : public LogicalOperatorVisitor { -public: - explicit LogicalBoundNodeVisitor(BoundNodeVisitor &parent) : parent(parent) { - } - - void VisitExpression(unique_ptr *expression) override { - auto &expr = **expression; - parent.VisitExpression(*expression); - VisitExpressionChildren(expr); - } - -protected: - BoundNodeVisitor &parent; -}; - -void BoundNodeVisitor::VisitBoundTableRef(BoundTableRef &ref) { - switch (ref.type) { - case TableReferenceType::EXPRESSION_LIST: { - auto &bound_expr_list = ref.Cast(); - for (auto &expr_list : bound_expr_list.values) { - for (auto &expr : expr_list) { - VisitExpression(expr); - } - } - break; - } - case TableReferenceType::JOIN: { - auto &bound_join = ref.Cast(); - if (bound_join.condition) { - VisitExpression(bound_join.condition); - } - VisitBoundTableRef(*bound_join.left); - VisitBoundTableRef(*bound_join.right); - break; - } - case TableReferenceType::SUBQUERY: { - auto &bound_subquery = ref.Cast(); - VisitBoundQueryNode(*bound_subquery.subquery); - break; - } - case TableReferenceType::TABLE_FUNCTION: { - auto &bound_table_function = ref.Cast(); - LogicalBoundNodeVisitor node_visitor(*this); - if (bound_table_function.get) { - node_visitor.VisitOperator(*bound_table_function.get); - } - if (bound_table_function.subquery) { - VisitBoundTableRef(*bound_table_function.subquery); - } - break; - } - case TableReferenceType::EMPTY_FROM: - case TableReferenceType::BASE_TABLE: - case TableReferenceType::CTE: - break; - default: - throw NotImplementedException("Unimplemented table reference type (%s) in ExpressionIterator", - EnumUtil::ToString(ref.type)); - } -} - } // namespace duckdb diff --git 
a/src/duckdb/src/planner/filter/bloom_filter.cpp b/src/duckdb/src/planner/filter/bloom_filter.cpp new file mode 100644 index 000000000..6950f0264 --- /dev/null +++ b/src/duckdb/src/planner/filter/bloom_filter.cpp @@ -0,0 +1,230 @@ +#include "duckdb/planner/filter/bloom_filter.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/common/operator/subtract.hpp" + +namespace duckdb { + +static constexpr idx_t MAX_NUM_SECTORS = (1ULL << 26); +static constexpr idx_t MIN_NUM_BITS_PER_KEY = 12; +static constexpr idx_t MIN_NUM_BITS = 512; +static constexpr idx_t LOG_SECTOR_SIZE = 6; // a sector is 64 bits, log2(64) = 6 +static constexpr idx_t SHIFT_MASK = 0x3F3F3F3F3F3F3F3F; // 6 bits for 64 positions +static constexpr idx_t N_BITS = 4; // the number of bits to set per hash + +void BloomFilter::Initialize(ClientContext &context_p, idx_t number_of_rows) { + BufferManager &buffer_manager = BufferManager::GetBufferManager(context_p); + + const idx_t min_bits = std::max(MIN_NUM_BITS, number_of_rows * MIN_NUM_BITS_PER_KEY); + num_sectors = std::min(NextPowerOfTwo(min_bits) >> LOG_SECTOR_SIZE, MAX_NUM_SECTORS); + bitmask = num_sectors - 1; + + buf_ = buffer_manager.GetBufferAllocator().Allocate(64 + num_sectors * sizeof(uint64_t)); + // make sure blocks is a 64-byte aligned pointer, i.e., cache-line aligned + bf = reinterpret_cast((64ULL + reinterpret_cast(buf_.get())) & ~63ULL); + std::fill_n(bf, num_sectors, 0); + + initialized = true; +} + +inline uint64_t GetMask(const hash_t hash) { + const uint64_t shifts = hash & SHIFT_MASK; + const auto shifts_8 = reinterpret_cast(&shifts); + + uint64_t mask = 0; + + for (idx_t bit_idx = 8 - N_BITS; bit_idx < 8; bit_idx++) { + const uint8_t bit_pos = shifts_8[bit_idx]; + mask |= (1ULL << bit_pos); + } + + return mask; +} + +void BloomFilter::InsertHashes(const Vector &hashes_v, idx_t count) const { + auto hashes = FlatVector::GetData(hashes_v); + for (idx_t i = 0; i < count; i++) { + InsertOne(hashes[i]); + } +} + +idx_t BloomFilter::LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, const idx_t count) const { + D_ASSERT(hashes_v.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(hashes_v.GetType() == LogicalType::HASH); + + const auto hashes = FlatVector::GetData(hashes_v); + idx_t found_count = 0; + for (idx_t i = 0; i < count; i++) { + result_sel.set_index(found_count, i); + found_count += LookupOne(hashes[i]); + } + return found_count; +} + +inline void BloomFilter::InsertOne(const hash_t hash) const { + D_ASSERT(initialized); + const uint64_t bf_offset = hash & bitmask; + const uint64_t mask = GetMask(hash); + std::atomic &slot = *reinterpret_cast *>(&bf[bf_offset]); + + slot.fetch_or(mask, std::memory_order_relaxed); +} + +inline bool BloomFilter::LookupOne(const uint64_t hash) const { + D_ASSERT(initialized); + const uint64_t bf_offset = hash & bitmask; + const uint64_t mask = GetMask(hash); + + return (bf[bf_offset] & mask) == mask; +} + +string BFTableFilter::ToString(const string &column_name) const { + return column_name + " IN BF(" + key_column_name + ")"; +} + +void BFTableFilter::HashInternal(Vector &keys_v, const SelectionVector &sel, const idx_t approved_count, + BFTableFilterState &state) { + if (sel.IsSet()) { + state.keys_sliced_v.Slice(keys_v, sel, approved_count); + VectorOperations::Hash(state.keys_sliced_v, state.hashes_v, approved_count); + } else { + VectorOperations::Hash(keys_v, state.hashes_v, approved_count); + } +} + +idx_t BFTableFilter::Filter(Vector &keys_v, SelectionVector 
&sel, idx_t &approved_tuple_count, + BFTableFilterState &state) const { + if (state.current_capacity < approved_tuple_count) { + state.hashes_v.Initialize(false, approved_tuple_count); + state.bf_sel.Initialize(approved_tuple_count); + state.current_capacity = approved_tuple_count; + } + + HashInternal(keys_v, sel, approved_tuple_count, state); + + idx_t found_count; + if (state.hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) { + const auto constant_hash = *ConstantVector::GetData(state.hashes_v); + const bool found = this->filter.LookupOne(constant_hash); + found_count = found ? approved_tuple_count : 0; + } else { + state.hashes_v.Flatten(approved_tuple_count); + found_count = this->filter.LookupHashes(state.hashes_v, state.bf_sel, approved_tuple_count); + } + + // all the elements have been found, we don't need to translate anything + if (found_count == approved_tuple_count) { + return approved_tuple_count; + } + + if (sel.IsSet()) { + for (idx_t idx = 0; idx < found_count; idx++) { + const idx_t flat_sel_idx = state.bf_sel.get_index(idx); + const idx_t original_sel_idx = sel.get_index(flat_sel_idx); + sel.set_index(idx, original_sel_idx); + } + } else { + sel.Initialize(state.bf_sel); + } + + approved_tuple_count = found_count; + return approved_tuple_count; +} + +bool BFTableFilter::FilterValue(const Value &value) const { + const auto hash = value.Hash(); + return filter.LookupOne(hash); +} + +template +static FilterPropagateResult TemplatedCheckStatistics(const BloomFilter &bf, const BaseStatistics &stats) { + if (!NumericStats::HasMinMax(stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto min = NumericStats::GetMin(stats); + const auto max = NumericStats::GetMax(stats); + if (min > max) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; // Invalid stats + } + T range_typed; + if (!TrySubtractOperator::Operation(max, min, range_typed) || range_typed > 2048) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; // Overflow or too wide of a range + } + const auto range = NumericCast(range_typed); + + T val = min; + idx_t hits = 0; + for (idx_t i = 0; i <= range; i++) { + hits += bf.LookupOne(Hash(val)); + val += i < range; // Avoids potential signed integer overflow on the last iteration + } + + if (hits == 0) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (hits == range + 1) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +FilterPropagateResult BFTableFilter::CheckStatistics(BaseStatistics &stats) const { + switch (stats.GetType().InternalType()) { + case PhysicalType::UINT8: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::UINT16: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::UINT32: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::UINT64: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::UINT128: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::INT8: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::INT16: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::INT32: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::INT64: + return TemplatedCheckStatistics(filter, stats); + case PhysicalType::INT128: + return TemplatedCheckStatistics(filter, stats); + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +bool BFTableFilter::Equals(const TableFilter &other) const { + if 
(!TableFilter::Equals(other)) { + return false; + } + return false; +} +unique_ptr BFTableFilter::Copy() const { + return make_uniq(this->filter, this->filters_null_values, this->key_column_name, this->key_type); +} + +unique_ptr BFTableFilter::ToExpression(const Expression &column) const { + auto bound_constant = make_uniq(Value(true)); + return std::move(bound_constant); +} + +void BFTableFilter::Serialize(Serializer &serializer) const { + TableFilter::Serialize(serializer); + serializer.WriteProperty(200, "filters_null_values", filters_null_values); + serializer.WriteProperty(201, "key_column_name", key_column_name); + serializer.WriteProperty(202, "key_type", key_type); +} + +unique_ptr BFTableFilter::Deserialize(Deserializer &deserializer) { + auto filters_null_values = deserializer.ReadProperty(200, "filters_null_values"); + auto key_column_name = deserializer.ReadProperty(201, "key_column_name"); + auto key_type = deserializer.ReadProperty(202, "key_type"); + + BloomFilter filter; + auto result = make_uniq(filter, filters_null_values, key_column_name, key_type); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/constant_filter.cpp b/src/duckdb/src/planner/filter/constant_filter.cpp index 5e1f39991..be43a4b0c 100644 --- a/src/duckdb/src/planner/filter/constant_filter.cpp +++ b/src/duckdb/src/planner/filter/constant_filter.cpp @@ -57,7 +57,13 @@ FilterPropagateResult ConstantFilter::CheckStatistics(BaseStatistics &stats) con result = NumericStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); break; case PhysicalType::VARCHAR: - result = StringStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); + switch (stats.GetStatsType()) { + case StatisticsType::STRING_STATS: + result = StringStats::CheckZonemap(stats, comparison_type, array_ptr(&constant, 1)); + break; + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } break; default: return FilterPropagateResult::NO_PRUNING_POSSIBLE; diff --git a/src/duckdb/src/planner/filter/expression_filter.cpp b/src/duckdb/src/planner/filter/expression_filter.cpp index 8e9b3299f..a86433f00 100644 --- a/src/duckdb/src/planner/filter/expression_filter.cpp +++ b/src/duckdb/src/planner/filter/expression_filter.cpp @@ -27,6 +27,11 @@ bool ExpressionFilter::EvaluateWithConstant(ExpressionExecutor &executor, const } FilterPropagateResult ExpressionFilter::CheckStatistics(BaseStatistics &stats) const { + if (stats.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + // Delegate to GeometryStats for geometry types + return GeometryStats::CheckZonemap(stats, expr); + } + // we cannot prune based on arbitrary expressions currently return FilterPropagateResult::NO_PRUNING_POSSIBLE; } diff --git a/src/duckdb/src/planner/filter/in_filter.cpp b/src/duckdb/src/planner/filter/in_filter.cpp index bc9b874b6..fee680ad1 100644 --- a/src/duckdb/src/planner/filter/in_filter.cpp +++ b/src/duckdb/src/planner/filter/in_filter.cpp @@ -23,6 +23,10 @@ InFilter::InFilter(vector values_p) : TableFilter(TableFilterType::IN_FIL } FilterPropagateResult InFilter::CheckStatistics(BaseStatistics &stats) const { + if (!stats.CanHaveNoNull()) { + // no non-null values are possible: always false + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } switch (values[0].type().InternalType()) { case PhysicalType::UINT8: case PhysicalType::UINT16: diff --git a/src/duckdb/src/planner/filter/optional_filter.cpp b/src/duckdb/src/planner/filter/optional_filter.cpp index 404dac9ba..ac4e3da60 100644 --- 
a/src/duckdb/src/planner/filter/optional_filter.cpp +++ b/src/duckdb/src/planner/filter/optional_filter.cpp @@ -20,6 +20,12 @@ unique_ptr OptionalFilter::ToExpression(const Expression &column) co return child_filter->ToExpression(column); } +idx_t OptionalFilter::FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, + TableFilterState &filter_state, const idx_t scan_count, + idx_t &approved_tuple_count) const { + return scan_count; +} + unique_ptr OptionalFilter::Copy() const { auto copy = make_uniq(); copy->child_filter = child_filter->Copy(); diff --git a/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp b/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp new file mode 100644 index 000000000..757288b88 --- /dev/null +++ b/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp @@ -0,0 +1,114 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/selectivity_optional_filter +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" +#include "duckdb/planner/table_filter_state.hpp" + +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/function/compression/compression.hpp" + +namespace duckdb { + +constexpr float SelectivityOptionalFilter::MIN_MAX_THRESHOLD; +constexpr idx_t SelectivityOptionalFilter::MIN_MAX_CHECK_N; + +constexpr float SelectivityOptionalFilter::BF_THRESHOLD; +constexpr idx_t SelectivityOptionalFilter::BF_CHECK_N; + +SelectivityOptionalFilterState::SelectivityStats::SelectivityStats(const idx_t n_vectors_to_check, + const float selectivity_threshold) + : tuples_accepted(0), tuples_processed(0), vectors_processed(0), n_vectors_to_check(n_vectors_to_check), + selectivity_threshold(selectivity_threshold), status(FilterStatus::ACTIVE) { +} + +void SelectivityOptionalFilterState::SelectivityStats::Update(idx_t accepted, idx_t processed) { + if (vectors_processed < n_vectors_to_check) { + tuples_accepted += accepted; + tuples_processed += processed; + vectors_processed += 1; + + // pause the filter if we processed enough vectors and the selectivity is too high + if (vectors_processed == n_vectors_to_check) { + if (GetSelectivity() >= selectivity_threshold) { + status = FilterStatus::PAUSED_DUE_TO_HIGH_SELECTIVITY; + } + } + } +} + +bool SelectivityOptionalFilterState::SelectivityStats::IsActive() const { + return status == FilterStatus::ACTIVE; +} +double SelectivityOptionalFilterState::SelectivityStats::GetSelectivity() const { + if (tuples_processed == 0) { + return 1.0; + } + return static_cast(tuples_accepted) / static_cast(tuples_processed); +} + +SelectivityOptionalFilter::SelectivityOptionalFilter(unique_ptr filter, const float selectivity_threshold, + const idx_t n_vectors_to_check) + : OptionalFilter(std::move(filter)), selectivity_threshold(selectivity_threshold), + n_vectors_to_check(n_vectors_to_check) { +} + +FilterPropagateResult SelectivityOptionalFilter::CheckStatistics(BaseStatistics &stats) const { + // TODO: A potential optimization would be to pause the filter for this row group if the stats return always true, + // but this needs to happen thread local, as other threads scan other row groups + return child_filter->CheckStatistics(stats); +} + +void SelectivityOptionalFilter::Serialize(Serializer &serializer) const { + OptionalFilter::Serialize(serializer); + 
serializer.WritePropertyWithDefault(201, "selectivity_threshold", selectivity_threshold); + serializer.WritePropertyWithDefault(202, "n_vectors_to_check", n_vectors_to_check); +} + +unique_ptr SelectivityOptionalFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new SelectivityOptionalFilter(nullptr, 0.5f, 100)); + deserializer.ReadPropertyWithDefault>(200, "child_filter", result->child_filter); + deserializer.ReadPropertyWithDefault(201, "selectivity_threshold", result->selectivity_threshold); + deserializer.ReadPropertyWithDefault(202, "n_vectors_to_check", result->n_vectors_to_check); + return std::move(result); +} +void SelectivityOptionalFilter::FiltersNullValues(const LogicalType &type, bool &filters_nulls, + bool &filters_valid_values, TableFilterState &filter_state) const { + const auto &state = filter_state.Cast(); + return ConstantFun::FiltersNullValues(type, *this->child_filter, filters_nulls, filters_valid_values, + *state.child_state); +} +unique_ptr SelectivityOptionalFilter::InitializeState(ClientContext &context) const { + D_ASSERT(child_filter); + auto child_filter_state = TableFilterState::Initialize(context, *child_filter); + return make_uniq(std::move(child_filter_state), this->n_vectors_to_check, + this->selectivity_threshold); +} + +idx_t SelectivityOptionalFilter::FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, + TableFilterState &filter_state, const idx_t scan_count, + idx_t &approved_tuple_count) const { + auto &state = filter_state.Cast(); + + if (state.stats.IsActive()) { + const idx_t approved_before = approved_tuple_count; + const idx_t accepted_count = ColumnSegment::FilterSelection( + sel, vector, vdata, *child_filter, *state.child_state, scan_count, approved_tuple_count); + + state.stats.Update(accepted_count, approved_before); + return accepted_count; + } + return scan_count; +} + +unique_ptr SelectivityOptionalFilter::Copy() const { + auto copy = make_uniq(child_filter->Copy(), selectivity_threshold, n_vectors_to_check); + return duckdb::unique_ptr_cast(std::move(copy)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/logical_operator.cpp b/src/duckdb/src/planner/logical_operator.cpp index e16062573..016b7d605 100644 --- a/src/duckdb/src/planner/logical_operator.cpp +++ b/src/duckdb/src/planner/logical_operator.cpp @@ -31,6 +31,19 @@ vector LogicalOperator::GetColumnBindings() { return {ColumnBinding(0, 0)}; } +idx_t LogicalOperator::GetRootIndex() { + auto bindings = GetColumnBindings(); + if (bindings.empty()) { + throw InternalException("Empty bindings in GetRootIndex"); + } + auto root_index = bindings[0].table_index; + for (idx_t i = 1; i < bindings.size(); i++) { + if (bindings[i].table_index != root_index) { + throw InternalException("GetRootIndex - multiple column bindings found"); + } + } + return root_index; +} void LogicalOperator::SetParamsEstimatedCardinality(InsertionOrderPreservingMap &result) const { if (has_estimated_cardinality) { result[RenderTreeNode::ESTIMATED_CARDINALITY] = StringUtil::Format("%llu", estimated_cardinality); diff --git a/src/duckdb/src/planner/logical_operator_visitor.cpp b/src/duckdb/src/planner/logical_operator_visitor.cpp index 5e96a5bbb..b7723d640 100644 --- a/src/duckdb/src/planner/logical_operator_visitor.cpp +++ b/src/duckdb/src/planner/logical_operator_visitor.cpp @@ -85,7 +85,6 @@ void LogicalOperatorVisitor::VisitChildOfOperatorWithProjectionMap(LogicalOperat void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator 
&op, const std::function *child)> &callback) { - switch (op.type) { case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { auto &get = op.Cast(); diff --git a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp index beee4f121..f15245c5d 100644 --- a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp +++ b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp @@ -126,7 +126,7 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria throw InternalException("Copy function \"%s\" has neither bind nor (de)serialize", function.name); } - CopyFunctionBindInput function_bind_input(*copy_info); + CopyFunctionBindInput function_bind_input(*copy_info, function.function_info); auto names_to_write = GetNamesWithoutPartitions(names, partition_columns, write_partition_columns); auto types_to_write = GetTypesWithoutPartitions(expected_types, partition_columns, write_partition_columns); bind_data = function.copy_to_bind(context, function_bind_input, names_to_write, types_to_write); diff --git a/src/duckdb/src/planner/operator/logical_create_index.cpp b/src/duckdb/src/planner/operator/logical_create_index.cpp index e1bc0f0ee..44dcab583 100644 --- a/src/duckdb/src/planner/operator/logical_create_index.cpp +++ b/src/duckdb/src/planner/operator/logical_create_index.cpp @@ -10,7 +10,6 @@ LogicalCreateIndex::LogicalCreateIndex(unique_ptr info_p, vecto TableCatalogEntry &table_p, unique_ptr alter_table_info) : LogicalOperator(LogicalOperatorType::LOGICAL_CREATE_INDEX), info(std::move(info_p)), table(table_p), alter_table_info(std::move(alter_table_info)) { - for (auto &expr : expressions_p) { unbound_expressions.push_back(expr->Copy()); } @@ -27,7 +26,6 @@ LogicalCreateIndex::LogicalCreateIndex(ClientContext &context, unique_ptr(std::move(info_p))), table(BindTable(context, *info)), alter_table_info(unique_ptr_cast(std::move(alter_table_info))) { - for (auto &expr : expressions_p) { unbound_expressions.push_back(expr->Copy()); } diff --git a/src/duckdb/src/planner/operator/logical_dependent_join.cpp b/src/duckdb/src/planner/operator/logical_dependent_join.cpp index 2e46dbc78..70af8444a 100644 --- a/src/duckdb/src/planner/operator/logical_dependent_join.cpp +++ b/src/duckdb/src/planner/operator/logical_dependent_join.cpp @@ -3,7 +3,7 @@ namespace duckdb { LogicalDependentJoin::LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition) : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), join_condition(std::move(condition)), correlated_columns(std::move(correlated_columns)) { @@ -17,7 +17,7 @@ LogicalDependentJoin::LogicalDependentJoin(JoinType join_type) unique_ptr LogicalDependentJoin::Create(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition) { return make_uniq(std::move(left), std::move(right), std::move(correlated_columns), type, std::move(condition)); diff --git a/src/duckdb/src/planner/operator/logical_empty_result.cpp b/src/duckdb/src/planner/operator/logical_empty_result.cpp index 12c1653b3..b745228b5 100644 --- a/src/duckdb/src/planner/operator/logical_empty_result.cpp +++ b/src/duckdb/src/planner/operator/logical_empty_result.cpp @@ -4,7 +4,6 @@ namespace duckdb { LogicalEmptyResult::LogicalEmptyResult(unique_ptr op) : LogicalOperator(LogicalOperatorType::LOGICAL_EMPTY_RESULT) { - 
this->bindings = op->GetColumnBindings(); op->ResolveOperatorTypes(); diff --git a/src/duckdb/src/planner/operator/logical_vacuum.cpp b/src/duckdb/src/planner/operator/logical_vacuum.cpp index 36352a0ea..ce4a76951 100644 --- a/src/duckdb/src/planner/operator/logical_vacuum.cpp +++ b/src/duckdb/src/planner/operator/logical_vacuum.cpp @@ -1,5 +1,5 @@ #include "duckdb/planner/operator/logical_vacuum.hpp" - +#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -46,11 +46,14 @@ unique_ptr LogicalVacuum::Deserialize(Deserializer &deserialize auto &context = deserializer.Get(); auto binder = Binder::CreateBinder(context); auto bound_table = binder->Bind(*info.ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw InvalidInputException("can only vacuum or analyze base tables"); + } + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { throw InvalidInputException("can only vacuum or analyze base tables"); } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; result->SetTable(table); // FIXME: we should probably verify that the 'column_id_map' and 'columns' are the same on the bound table after // deserialization? diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index 78bca8a02..ca5e72d88 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -40,8 +40,8 @@ void Planner::CreatePlan(SQLStatement &statement) { // first bind the tables and columns to the catalog bool parameters_resolved = true; try { - profiler.StartPhase(MetricsType::PLANNER_BINDING); - binder->parameters = &bound_parameters; + profiler.StartPhase(MetricType::PLANNER_BINDING); + binder->SetParameters(bound_parameters); auto bound_statement = binder->Bind(statement); profiler.EndPhase(); diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index 7b2909c6d..23dd23ec4 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -2,6 +2,7 @@ #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/common/operator/add.hpp" +#include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/planner/binder.hpp" @@ -18,9 +19,8 @@ namespace duckdb { -FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated, - bool perform_delim, bool any_join, - optional_ptr parent) +FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim, + bool any_join, optional_ptr parent) : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), perform_delim(perform_delim), any_join(any_join), parent(parent) { for (idx_t i = 0; i < correlated_columns.size(); i++) { @@ -30,8 +30,7 @@ FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated_columns, +static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, const CorrelatedColumns &correlated_columns, vector bindings, idx_t 
base_offset, bool perform_delim) { auto col_count = perform_delim ? correlated_columns.size() : 1; for (idx_t i = 0; i < col_count; i++) { @@ -50,7 +49,7 @@ static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, unique_ptr FlattenDependentJoins::DecorrelateIndependent(Binder &binder, unique_ptr plan) { - vector correlated; + CorrelatedColumns correlated; FlattenDependentJoins flatten(binder, correlated); return flatten.Decorrelate(std::move(plan)); } @@ -80,12 +79,12 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrsecond = false; // rewrite - idx_t lateral_depth = 0; + idx_t next_lateral_depth = 0; - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, next_lateral_depth); rewriter.VisitOperator(*plan); - RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, lateral_depth, true); + RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, next_lateral_depth, true); recursive_rewriter.VisitOperator(*plan); } else { op.children[0] = Decorrelate(std::move(op.children[0])); @@ -94,8 +93,8 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptr(op.correlated_columns[0].binding.table_index); + const auto &op_col = op.correlated_columns[op.correlated_columns.GetDelimIndex()]; + auto window = make_uniq(op_col.binding.table_index); auto row_number = make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; @@ -114,9 +113,9 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1], op.is_lateral_join, lateral_depth); if (delim_join->children[1]->type == LogicalOperatorType::LOGICAL_MATERIALIZED_CTE) { - auto &cte = delim_join->children[1]->Cast(); + auto &cte_ref = delim_join->children[1]->Cast(); // check if the left side of the CTE has correlated expressions - auto entry = flatten.has_correlated_expressions.find(*cte.children[0]); + auto entry = flatten.has_correlated_expressions.find(*cte_ref.children[0]); if (entry != flatten.has_correlated_expressions.end()) { if (!entry->second) { // the left side of the CTE has no correlated expressions, we can push the DEPENDENT_JOIN down @@ -132,7 +131,7 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1] = flatten.PushDownDependentJoin(std::move(delim_join->children[1]), propagate_null_values, lateral_depth); data_offset = flatten.data_offset; - auto left_offset = delim_join->children[0]->GetColumnBindings().size(); + const auto left_offset = delim_join->children[0]->GetColumnBindings().size(); if (!parent) { delim_offset = left_offset + flatten.delim_offset; } @@ -214,7 +213,6 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptr(); + binder.recursive_ctes[setop.table_index] = &setop; + has_correlated_expressions[op] = has_correlation; + if (has_correlation) { + setop.correlated_columns = correlated_columns; + } + } + child_idx++; } @@ -263,6 +271,7 @@ bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator &op, boo return true; } // Found a materialized CTE, subtree correlation depends on the CTE node + has_correlated_expressions[op] = has_correlated_expressions[*cte_node]; return has_correlated_expressions[*cte_node]; } // No CTE found: subtree is correlated @@ -281,47 +290,32 @@ bool FlattenDependentJoins::DetectCorrelatedExpressions(LogicalOperator &op, boo binder.recursive_ctes[setop.table_index] = &setop; if (has_correlation) { 
setop.correlated_columns = correlated_columns; - MarkSubtreeCorrelated(*op.children[1].get()); - } - } - - if (op.type == LogicalOperatorType::LOGICAL_MATERIALIZED_CTE) { - auto &setop = op.Cast(); - binder.recursive_ctes[setop.table_index] = &setop; - // only mark the entire subtree as correlated if the materializing side is correlated - auto entry = has_correlated_expressions.find(*op.children[0]); - if (entry != has_correlated_expressions.end()) { - if (has_correlation && entry->second) { - setop.correlated_columns = correlated_columns; - MarkSubtreeCorrelated(*op.children[1].get()); - } + MarkSubtreeCorrelated(*op.children[1].get(), setop.table_index); } } return has_correlation; } -bool FlattenDependentJoins::MarkSubtreeCorrelated(LogicalOperator &op) { +bool FlattenDependentJoins::MarkSubtreeCorrelated(LogicalOperator &op, idx_t cte_index) { // Do not mark base table scans as correlated auto entry = has_correlated_expressions.find(op); D_ASSERT(entry != has_correlated_expressions.end()); bool has_correlation = entry->second; for (auto &child : op.children) { - has_correlation |= MarkSubtreeCorrelated(*child.get()); + has_correlation |= MarkSubtreeCorrelated(*child.get(), cte_index); } if (op.type != LogicalOperatorType::LOGICAL_GET || op.children.size() == 1) { if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { // There may be multiple recursive CTEs. Only mark CTE_REFs as correlated, // IFF the CTE that we are reading from is correlated. auto &cteref = op.Cast(); - auto cte = binder.recursive_ctes.find(cteref.cte_index); - bool has_correlation = false; - if (cte != binder.recursive_ctes.end()) { - auto &rec_cte = cte->second->Cast(); - has_correlation = !rec_cte.correlated_columns.empty(); + if (cteref.cte_index != cte_index) { + has_correlated_expressions[op] = has_correlation; + return has_correlation; } - has_correlated_expressions[op] = has_correlation; - return has_correlation; + has_correlated_expressions[op] = true; + return true; } else { has_correlated_expressions[op] = has_correlation; } @@ -697,6 +691,42 @@ unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal return plan; } } else if (join.join_type == JoinType::MARK) { + if (!left_has_correlation && right_has_correlation) { + // found a MARK join where the left side has no correlation + + ColumnBinding right_binding; + + // there may still be correlation on the right side that we have to deal with + // push into the right side if necessary or decorrelate it independently otherwise + plan->children[1] = PushDownDependentJoinInternal(std::move(plan->children[1]), + parent_propagate_null_values, lateral_depth); + right_binding = this->base_binding; + + // now push into the left side of the MARK join even though it has no correlation + // this is necessary to add the correlated columns to the column bindings and allow + // the join condition to be rewritten correctly + plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), + parent_propagate_null_values, lateral_depth); + + auto left_binding = this->base_binding; + + // add the correlated columns to the join conditions + for (idx_t i = 0; i < correlated_columns.size(); i++) { + JoinCondition cond; + cond.left = make_uniq( + correlated_columns[i].type, + ColumnBinding(left_binding.table_index, left_binding.column_index + i)); + cond.right = make_uniq( + correlated_columns[i].type, + ColumnBinding(right_binding.table_index, right_binding.column_index + i)); + cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + + auto 
&comparison_join = join.Cast(); + comparison_join.conditions.push_back(std::move(cond)); + } + return plan; + } + // push the child into the LHS plan->children[0] = PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth); @@ -998,7 +1028,6 @@ unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal } case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: { - #ifdef DEBUG plan->children[0]->ResolveOperatorTypes(); plan->children[1]->ResolveOperatorTypes(); @@ -1042,7 +1071,8 @@ unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal } } - RewriteCTEScan cte_rewriter(table_index, correlated_columns); + RewriteCTEScan cte_rewriter(table_index, correlated_columns, + plan->type == LogicalOperatorType::LOGICAL_RECURSIVE_CTE); cte_rewriter.VisitOperator(*plan->children[1]); parent_propagate_null_values = false; diff --git a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp index 9f1c679a1..8554f3f5b 100644 --- a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp @@ -7,7 +7,7 @@ namespace duckdb { -HasCorrelatedExpressions::HasCorrelatedExpressions(const vector &correlated, bool lateral, +HasCorrelatedExpressions::HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral, idx_t lateral_depth) : has_correlated_expressions(false), lateral(lateral), correlated_columns(correlated), lateral_depth(lateral_depth) { diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp index 903840dda..10004d8f7 100644 --- a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -9,7 +9,6 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_dependent_join.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { @@ -71,14 +70,14 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRef } //! Helper class used to recursively rewrite correlated expressions within nested subqueries. 
-class RewriteCorrelatedRecursive : public BoundNodeVisitor { +class RewriteCorrelatedRecursive : public LogicalOperatorVisitor { public: RewriteCorrelatedRecursive(ColumnBinding base_binding, column_binding_map_t &correlated_map); - void VisitBoundTableRef(BoundTableRef &ref) override; - void VisitExpression(unique_ptr &expression) override; + void VisitOperator(LogicalOperator &op) override; + void VisitExpression(unique_ptr *expression) override; - void RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery); + void RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &subquery); ColumnBinding base_binding; column_binding_map_t &correlated_map; @@ -92,7 +91,7 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryE // subquery detected within this subquery // recursively rewrite it using the RewriteCorrelatedRecursive class RewriteCorrelatedRecursive rewrite(base_binding, correlated_map); - rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery); + rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery.plan); return nullptr; } @@ -101,40 +100,30 @@ RewriteCorrelatedRecursive::RewriteCorrelatedRecursive(ColumnBinding base_bindin : base_binding(base_binding), correlated_map(correlated_map) { } -void RewriteCorrelatedRecursive::VisitBoundTableRef(BoundTableRef &ref) { - if (ref.type == TableReferenceType::JOIN) { +void RewriteCorrelatedRecursive::VisitOperator(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - auto &bound_join = ref.Cast(); - for (auto &corr : bound_join.correlated_columns) { + auto &dep_join = op.Cast(); + for (auto &corr : dep_join.correlated_columns) { auto entry = correlated_map.find(corr.binding); if (entry != correlated_map.end()) { corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); } } - } else if (ref.type == TableReferenceType::SUBQUERY) { - auto &subquery = ref.Cast(); - RewriteCorrelatedSubquery(*subquery.binder, *subquery.subquery); - return; } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } -void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery) { - // rewrite the binding in the correlated list of the subquery) - for (auto &corr : binder.correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - VisitBoundQueryNode(subquery); +void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &op) { + VisitOperator(op); } -void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &expression) { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { +void RewriteCorrelatedRecursive::VisitExpression(unique_ptr *expression) { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { // bound column reference - auto &bound_colref = expression->Cast(); + auto &bound_colref = expr.Cast(); if (bound_colref.depth == 0) { // not a correlated column, ignore return; @@ -148,13 +137,13 @@ void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &express bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); bound_colref.depth--; } - } else if 
(expression->GetExpressionType() == ExpressionType::SUBQUERY) { + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { // we encountered another subquery: rewrite recursively - auto &bound_subquery = expression->Cast(); - RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery); + auto &bound_subquery = expr.Cast(); + RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery.plan); } // recurse into the children of this subquery - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t &replacement_map) diff --git a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp index 78b3b21ec..7df4f13a8 100644 --- a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp @@ -14,8 +14,10 @@ namespace duckdb { -RewriteCTEScan::RewriteCTEScan(idx_t table_index, const vector &correlated_columns) - : table_index(table_index), correlated_columns(correlated_columns) { +RewriteCTEScan::RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns, + bool rewrite_dependent_joins) + : table_index(table_index), correlated_columns(correlated_columns), + rewrite_dependent_joins(rewrite_dependent_joins) { } void RewriteCTEScan::VisitOperator(LogicalOperator &op) { @@ -29,7 +31,7 @@ void RewriteCTEScan::VisitOperator(LogicalOperator &op) { } cteref.correlated_columns += correlated_columns.size(); } - } else if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + } else if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN && rewrite_dependent_joins) { // There is another DependentJoin below the correlated recursive CTE. // We have to add the correlated columns of the recursive CTE to the // set of columns of this operator. @@ -49,7 +51,7 @@ void RewriteCTEScan::VisitOperator(LogicalOperator &op) { // The correlated columns must be placed at the beginning of the // correlated_columns list. Otherwise, further column accesses // and rewrites will fail. 
- join.correlated_columns.emplace(join.correlated_columns.begin(), corr); + join.correlated_columns.AddColumn(std::move(corr)); } } } diff --git a/src/duckdb/src/planner/table_binding.cpp b/src/duckdb/src/planner/table_binding.cpp index d9bdd71c7..c55d0be82 100644 --- a/src/duckdb/src/planner/table_binding.cpp +++ b/src/duckdb/src/planner/table_binding.cpp @@ -19,6 +19,10 @@ Binding::Binding(BindingType binding_type, BindingAlias alias_p, vector &Binding::GetColumnTypes() { + return types; +} + +const vector &Binding::GetColumnNames() { + return names; +} + +idx_t Binding::GetColumnCount() { + return GetColumnNames().size(); +} + +void Binding::SetColumnType(idx_t col_idx, LogicalType type_p) { + types[col_idx] = std::move(type_p); +} + string Binding::GetAlias() const { return alias.GetAlias(); } @@ -304,4 +336,42 @@ unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colre return arg; } +CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, idx_t index, + CTEType cte_type) + : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), cte_type(cte_type), + reference_count(0) { +} + +CTEBinding::CTEBinding(BindingAlias alias_p, shared_ptr bind_state_p, idx_t index) + : Binding(BindingType::CTE, std::move(alias_p), vector(), vector(), index), + cte_type(CTEType::CAN_BE_REFERENCED), reference_count(0), bind_state(std::move(bind_state_p)) { +} + +bool CTEBinding::CanBeReferenced() const { + return cte_type == CTEType::CAN_BE_REFERENCED; +} + +bool CTEBinding::IsReferenced() const { + return reference_count > 0; +} + +void CTEBinding::Reference() { + if (!CanBeReferenced()) { + throw InternalException("CTE cannot be referenced!"); + } + if (bind_state) { + // we have not bound the CTE yet - bind it + bind_state->Bind(*this); + + // copy over the names / types and initialize the binding + this->names = bind_state->names; + this->types = bind_state->types; + Initialize(); + + // finalize binding + bind_state.reset(); + } + reference_count++; +} + } // namespace duckdb diff --git a/src/duckdb/src/planner/table_filter_state.cpp b/src/duckdb/src/planner/table_filter_state.cpp index c542939dc..258938a6e 100644 --- a/src/duckdb/src/planner/table_filter_state.cpp +++ b/src/duckdb/src/planner/table_filter_state.cpp @@ -1,6 +1,8 @@ #include "duckdb/planner/table_filter_state.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" #include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include "duckdb/planner/filter/struct_filter.hpp" namespace duckdb { @@ -11,9 +13,16 @@ ExpressionFilterState::ExpressionFilterState(ClientContext &context, const Expre unique_ptr TableFilterState::Initialize(ClientContext &context, const TableFilter &filter) { switch (filter.filter_type) { - case TableFilterType::OPTIONAL_FILTER: - // optional filter is not executed - create an empty filter state - return make_uniq(); + case TableFilterType::BLOOM_FILTER: { + auto &bf = filter.Cast(); + return make_uniq(bf.GetKeyType()); + } + case TableFilterType::OPTIONAL_FILTER: { + // the optional filter may be executed if it is a SelectivityOptionalFilter + auto &optional_filter = filter.Cast(); + return optional_filter.InitializeState(context); + } + case TableFilterType::STRUCT_EXTRACT: { auto &struct_filter = filter.Cast(); return Initialize(context, *struct_filter.child_filter); diff --git a/src/duckdb/src/storage/block.cpp b/src/duckdb/src/storage/block.cpp 
index 7262277e5..d45c2584e 100644 --- a/src/duckdb/src/storage/block.cpp +++ b/src/duckdb/src/storage/block.cpp @@ -4,16 +4,16 @@ namespace duckdb { -Block::Block(Allocator &allocator, const block_id_t id, const idx_t block_size, const idx_t block_header_size) +Block::Block(BlockAllocator &allocator, const block_id_t id, const idx_t block_size, const idx_t block_header_size) : FileBuffer(allocator, FileBufferType::BLOCK, block_size, block_header_size), id(id) { } -Block::Block(Allocator &allocator, block_id_t id, uint32_t internal_size, idx_t block_header_size) +Block::Block(BlockAllocator &allocator, block_id_t id, uint32_t internal_size, idx_t block_header_size) : FileBuffer(allocator, FileBufferType::BLOCK, internal_size, block_header_size), id(id) { D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); } -Block::Block(Allocator &allocator, block_id_t id, BlockManager &block_manager) +Block::Block(BlockAllocator &allocator, block_id_t id, BlockManager &block_manager) : FileBuffer(allocator, FileBufferType::BLOCK, block_manager), id(id) { D_ASSERT((AllocSize() & (Storage::SECTOR_SIZE - 1)) == 0); } diff --git a/src/duckdb/src/storage/block_allocator.cpp b/src/duckdb/src/storage/block_allocator.cpp new file mode 100644 index 000000000..d93d97328 --- /dev/null +++ b/src/duckdb/src/storage/block_allocator.cpp @@ -0,0 +1,405 @@ +#include "duckdb/storage/block_allocator.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/parallel/concurrentqueue.hpp" +#include "duckdb/common/types/uuid.hpp" + +#if defined(_WIN32) +#include "duckdb/common/windows.hpp" +#else +#include +#endif +#ifdef __MVS__ +#include +#endif + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Memory Helpers +//===--------------------------------------------------------------------===// +static data_ptr_t AllocateVirtualMemory(const idx_t size) { +#if INTPTR_MAX == INT32_MAX + // Disable on 32-bit + return nullptr; +#endif + +#if defined(_WIN32) + // This returns nullptr on failure + return data_ptr_t(VirtualAlloc(nullptr, size, MEM_RESERVE, PAGE_NOACCESS)); +#else + const auto ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + return ptr == MAP_FAILED ? 
nullptr : data_ptr_cast(ptr); +#endif +} + +static void FreeVirtualMemory(const data_ptr_t pointer, const idx_t size) { + bool success; +#if defined(_WIN32) + success = VirtualFree(pointer, 0, MEM_RELEASE); +#else + success = munmap(pointer, size) == 0; +#endif + if (!success) { + throw InternalException("FreeVirtualMemory failed"); + } +} + +static void OnFirstAllocation(const data_ptr_t pointer, const idx_t size) { + bool success = true; +#if defined(_WIN32) + success = VirtualAlloc(pointer, size, MEM_COMMIT, PAGE_READWRITE); +#elif defined(__APPLE__) + // Nothing to do here +#else + // Pre-fault the memory + for (idx_t i = 0; i < size; i += 4096) { + pointer[i] = 0; + } +#endif + if (!success) { + throw InternalException("OnFirstAllocation failed"); + } +} + +static void OnDeallocation(const data_ptr_t pointer, const idx_t size) { + bool success; +#if defined(_WIN32) + success = VirtualFree(pointer, size, MEM_DECOMMIT); +#elif defined(__APPLE__) + success = madvise(pointer, size, MADV_FREE_REUSABLE) == 0; +#elif defined(__MVS__) + // the madvice functionality is not available on z/OS in any form + success = true; +#else + success = madvise(pointer, size, MADV_DONTNEED) == 0; +#endif + if (!success) { + throw InternalException("OnDeallocation failed"); + } +} + +//===--------------------------------------------------------------------===// +// BlockAllocatorThreadLocalState +//===--------------------------------------------------------------------===// +struct BlockQueue { + duckdb_moodycamel::ConcurrentQueue q; +}; + +class BlockAllocatorThreadLocalState { +public: + explicit BlockAllocatorThreadLocalState(const BlockAllocator &block_allocator_p) { + Initialize(block_allocator_p); + } + ~BlockAllocatorThreadLocalState() { + Clear(); + } + +public: + void TryInitialize(const BlockAllocator &block_allocator_p) { + // Local state can be invalidated if DB closes but thread stays alive + if (cached_uuid != block_allocator_p.uuid) { + Initialize(block_allocator_p); + } + } + + data_ptr_t Allocate() { + auto pointer = TryAllocateFromLocal(); + if (pointer) { + return pointer; + } + + // We have run out of local blocks + if (TryGetBatch(touched, *block_allocator->touched) || TryGetBatch(untouched, *block_allocator->untouched)) { + // We have refilled local blocks + pointer = TryAllocateFromLocal(); + D_ASSERT(pointer); + return pointer; + } + + // We have also run out of global blocks, use fallback allocator + return block_allocator->allocator.AllocateData(block_allocator->block_size); + } + + void Free(const data_ptr_t pointer) { + touched.push_back(block_allocator->GetBlockID(pointer)); + if (touched.size() < FREE_THRESHOLD) { + return; + } + + // Upon reaching the threshold, we return a local batch to global + std::sort(touched.begin(), touched.end()); + block_allocator->touched->q.enqueue_bulk(touched.end() - BATCH_SIZE, BATCH_SIZE); + touched.resize(touched.size() - BATCH_SIZE); + } + + void Clear() { + // Return all local blocks back to global + if (!touched.empty()) { + block_allocator->touched->q.enqueue_bulk(touched.begin(), touched.size()); + touched.clear(); + } + if (!untouched.empty()) { + block_allocator->untouched->q.enqueue_bulk(untouched.begin(), untouched.size()); + untouched.clear(); + } + } + +private: + void Initialize(const BlockAllocator &block_allocator_p) { + cached_uuid = block_allocator_p.uuid; + block_allocator = block_allocator_p; + untouched.clear(); + touched.clear(); + untouched.reserve(BATCH_SIZE); + touched.reserve(FREE_THRESHOLD); + } + + data_ptr_t 
TryAllocateFromLocal() { + if (!touched.empty()) { + const auto pointer = block_allocator->GetPointer(touched.back()); + touched.pop_back(); + return pointer; + } + if (!untouched.empty()) { + const auto pointer = block_allocator->GetPointer(untouched.back()); + untouched.pop_back(); + OnFirstAllocation(pointer, block_allocator->block_size); + return pointer; + } + return nullptr; + } + + static bool TryGetBatch(vector &local, BlockQueue &global) { + D_ASSERT(local.empty()); + local.resize(BATCH_SIZE); + const auto size = global.q.try_dequeue_bulk(local.begin(), BATCH_SIZE); + local.resize(size); + std::sort(local.begin(), local.end()); + return !local.empty(); + } + +private: + hugeint_t cached_uuid; + optional_ptr block_allocator; + + static constexpr idx_t BATCH_SIZE = 128; + static constexpr idx_t FREE_THRESHOLD = BATCH_SIZE * 2; + + vector untouched; + vector touched; +}; + +BlockAllocatorThreadLocalState &GetBlockAllocatorThreadLocalState(const BlockAllocator &block_allocator) { +#ifdef __MVS__ + auto allocator_state = BlockAllocatorThreadLocalState(block_allocator); + static __tlssim local_state_impl(allocator_state); + auto *local_state = local_state_impl.access(); + (*local_state).TryInitialize(block_allocator); + return *local_state; +#else + thread_local BlockAllocatorThreadLocalState local_state(block_allocator); + local_state.TryInitialize(block_allocator); + return local_state; +#endif +} + +//===--------------------------------------------------------------------===// +// BlockAllocator +//===--------------------------------------------------------------------===// +BlockAllocator::BlockAllocator(Allocator &allocator_p, const idx_t block_size_p, const idx_t virtual_memory_size_p, + const idx_t physical_memory_size_p) + : uuid(UUID::GenerateRandomUUID()), allocator(allocator_p), block_size(block_size_p), + block_size_div_shift(CountZeros::Trailing(block_size)), + virtual_memory_size(AlignValue(virtual_memory_size_p, block_size)), + virtual_memory_space(AllocateVirtualMemory(virtual_memory_size)), physical_memory_size(0), + untouched(make_unsafe_uniq()), touched(make_unsafe_uniq()) { + D_ASSERT(IsPowerOfTwo(block_size)); + Resize(physical_memory_size_p); +} + +BlockAllocator::~BlockAllocator() { + GetBlockAllocatorThreadLocalState(*this).Clear(); + if (IsActive()) { + FreeVirtualMemory(virtual_memory_space, virtual_memory_size); + } +} + +BlockAllocator &BlockAllocator::Get(DatabaseInstance &db) { + return *db.config.block_allocator; +} + +BlockAllocator &BlockAllocator::Get(AttachedDatabase &db) { + return Get(db.GetDatabase()); +} + +void BlockAllocator::Resize(const idx_t new_physical_memory_size) { + if (!IsActive()) { + return; + } + + lock_guard guard(physical_memory_lock); + if (new_physical_memory_size < physical_memory_size) { + throw InvalidInputException("The \"block_allocator_size\" setting cannot be reduced (current: %llu)", + physical_memory_size.load()); + } + if (new_physical_memory_size > virtual_memory_size) { + throw InvalidInputException("The \"block_allocator_size\" setting cannot be greater than the virtual memory " + "size (virtual memory size: %llu)", + virtual_memory_size); + } + + // Enqueue block IDs efficiently in batches + uint32_t block_ids[STANDARD_VECTOR_SIZE]; + const auto start = NumericCast(DivBlockSize(physical_memory_size)); + const auto end = NumericCast(DivBlockSize(new_physical_memory_size)); + for (auto block_id = start; block_id < end; block_id += STANDARD_VECTOR_SIZE) { + const auto next = MinValue(end - block_id, 
STANDARD_VECTOR_SIZE); + for (uint32_t i = 0; i < next; i++) { + block_ids[i] = block_id + i; + } + untouched->q.enqueue_bulk(block_ids, next); + } + + // Finally, update to the new size + physical_memory_size = new_physical_memory_size; +} + +bool BlockAllocator::IsActive() const { + return virtual_memory_space; +} + +bool BlockAllocator::IsEnabled() const { + return physical_memory_size.load(std::memory_order_relaxed) != 0; +} + +bool BlockAllocator::IsInPool(const data_ptr_t pointer) const { + return pointer >= virtual_memory_space && pointer < virtual_memory_space + virtual_memory_size; +} + +idx_t BlockAllocator::ModuloBlockSize(const idx_t n) const { + return n & (block_size - 1); +} + +idx_t BlockAllocator::DivBlockSize(const idx_t n) const { + return n >> block_size_div_shift; +} + +uint32_t BlockAllocator::GetBlockID(const data_ptr_t pointer) const { + D_ASSERT(IsInPool(pointer)); + const auto offset = NumericCast(pointer - virtual_memory_space); + D_ASSERT(ModuloBlockSize(offset) == 0); + const auto block_id = NumericCast(DivBlockSize(offset)); + VerifyBlockID(block_id); + return block_id; +} + +void BlockAllocator::VerifyBlockID(const uint32_t block_id) const { + D_ASSERT(block_id < NumericCast(virtual_memory_size / block_size)); +} + +data_ptr_t BlockAllocator::GetPointer(const uint32_t block_id) const { + VerifyBlockID(block_id); + return virtual_memory_space + NumericCast(block_id) * block_size; +} + +data_ptr_t BlockAllocator::AllocateData(const idx_t size) const { + if (!IsActive() || !IsEnabled() || size != block_size) { + return allocator.AllocateData(size); + } + return GetBlockAllocatorThreadLocalState(*this).Allocate(); +} + +void BlockAllocator::FreeData(const data_ptr_t pointer, const idx_t size) const { + if (!IsActive() || !IsInPool(pointer)) { + return allocator.FreeData(pointer, size); + } + D_ASSERT(size == block_size); + GetBlockAllocatorThreadLocalState(*this).Free(pointer); +} + +data_ptr_t BlockAllocator::ReallocateData(const data_ptr_t pointer, const idx_t old_size, const idx_t new_size) const { + if (old_size == new_size) { + return pointer; + } + + // If both the old and new allocation are not (or cannot be) in the pool, immediately use the fallback allocator + if (!IsActive() || (!IsInPool(pointer) && new_size != block_size)) { + return allocator.ReallocateData(pointer, old_size, new_size); + } + + // Either old or new can be in the pool: allocate, copy, and free + const auto new_pointer = AllocateData(new_size); + memcpy(new_pointer, pointer, MinValue(old_size, new_size)); + FreeData(pointer, old_size); + return new_pointer; +} + +bool BlockAllocator::SupportsFlush() const { + return (IsActive() && IsEnabled()) || Allocator::SupportsFlush(); +} + +void BlockAllocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) const { + if (IsActive() && IsEnabled()) { + GetBlockAllocatorThreadLocalState(*this).Clear(); + } + if (Allocator::SupportsFlush()) { + Allocator::ThreadFlush(allocator_background_threads, threshold, thread_count); + } +} + +void BlockAllocator::FlushAll(const optional_idx extra_memory) const { + if (IsActive() && IsEnabled() && extra_memory.IsValid()) { + FreeInternal(extra_memory.GetIndex()); + } + if (Allocator::SupportsFlush()) { + Allocator::FlushAll(); + } +} + +void BlockAllocator::FreeInternal(const idx_t extra_memory) const { + auto count = DivBlockSize(extra_memory); + unsafe_vector to_free_buffer; + to_free_buffer.resize(count); + count = touched->q.try_dequeue_bulk(to_free_buffer.begin(), 
count); + if (count == 0) { + return; + } + to_free_buffer.resize(count); + + // Sort so we can coalesce madvise calls + std::sort(to_free_buffer.begin(), to_free_buffer.end()); + + // Coalesce and free + uint32_t block_id_start = to_free_buffer[0]; + for (idx_t i = 1; i < to_free_buffer.size(); i++) { + const auto &previous_block_id = to_free_buffer[i - 1]; + const auto ¤t_block_id = to_free_buffer[i]; + if (previous_block_id == current_block_id - 1) { + continue; // Current is contiguous with previous block + } + + // Previous block is the last contiguous block starting from block_id_start, free them in one go + FreeContiguousBlocks(block_id_start, previous_block_id); + + // Continue coalescing from the current + block_id_start = current_block_id; + } + + // Don't forget the last one + FreeContiguousBlocks(block_id_start, to_free_buffer.back()); + + // Make freed blocks available to allocate again + untouched->q.enqueue_bulk(to_free_buffer.begin(), to_free_buffer.size()); +} + +void BlockAllocator::FreeContiguousBlocks(const uint32_t block_id_start, const uint32_t block_id_end_including) const { + const auto pointer = GetPointer(block_id_start); + const auto num_blocks = block_id_end_including - block_id_start + 1; + const auto size = num_blocks * block_size; + OnDeallocation(pointer, size); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp index dd9671af2..91c98d4b4 100644 --- a/src/duckdb/src/storage/buffer/block_handle.cpp +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -60,7 +60,6 @@ BlockHandle::~BlockHandle() { // NOLINT: allow internal exceptions unique_ptr AllocateBlock(BlockManager &block_manager, unique_ptr reusable_buffer, block_id_t block_id) { - if (reusable_buffer && reusable_buffer->GetHeaderSize() == block_manager.GetBlockHeaderSize()) { // re-usable buffer: re-use it if (reusable_buffer->GetBufferType() == FileBufferType::BLOCK) { diff --git a/src/duckdb/src/storage/buffer/block_manager.cpp b/src/duckdb/src/storage/buffer/block_manager.cpp index 47fef0ebf..5a3fd0f47 100644 --- a/src/duckdb/src/storage/buffer/block_manager.cpp +++ b/src/duckdb/src/storage/buffer/block_manager.cpp @@ -14,6 +14,22 @@ BlockManager::BlockManager(BufferManager &buffer_manager, const optional_idx blo block_alloc_size(block_alloc_size_p), block_header_size(block_header_size_p) { } +bool BlockManager::BlockIsRegistered(block_id_t block_id) { + lock_guard lock(blocks_lock); + // check if the block already exists + auto entry = blocks.find(block_id); + if (entry == blocks.end()) { + return false; + } + // already exists: check if it hasn't expired yet + auto existing_ptr = entry->second.lock(); + if (existing_ptr) { + //! it hasn't! 
return it + return true; + } + return false; +} + shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { lock_guard lock(blocks_lock); // check if the block already exists @@ -34,19 +50,32 @@ shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { } shared_ptr BlockManager::ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block, BufferHandle old_handle) { + shared_ptr old_block, BufferHandle old_handle, + ConvertToPersistentMode mode) { // register a block with the new block id auto new_block = RegisterBlock(block_id); D_ASSERT(new_block->GetState() == BlockState::BLOCK_UNLOADED); D_ASSERT(new_block->Readers() == 0); + if (mode == ConvertToPersistentMode::THREAD_SAFE) { + // safe mode - create a copy of the old block and operate on that + // this ensures we don't modify the old block - which allows other concurrent operations on the old block to + // continue + auto old_block_copy = buffer_manager.AllocateMemory(old_block->GetMemoryTag(), this, false); + auto copy_pin = buffer_manager.Pin(old_block_copy); + memcpy(copy_pin.Ptr(), old_handle.Ptr(), GetBlockSize()); + old_block = std::move(old_block_copy); + old_handle = std::move(copy_pin); + } + auto lock = old_block->GetLock(); D_ASSERT(old_block->GetState() == BlockState::BLOCK_LOADED); D_ASSERT(old_block->GetBuffer(lock)); if (old_block->Readers() > 1) { - throw InternalException("BlockManager::ConvertToPersistent - cannot be called for block %d as old_block has " - "multiple readers active", - block_id); + throw InternalException( + "BlockManager::ConvertToPersistent in destructive mode - cannot be called for block %d as old_block has " + "multiple readers active", + block_id); } // Temp buffers can be larger than the storage block size. @@ -76,10 +105,11 @@ shared_ptr BlockManager::ConvertToPersistent(QueryContext context, } shared_ptr BlockManager::ConvertToPersistent(QueryContext context, block_id_t block_id, - shared_ptr old_block) { + shared_ptr old_block, + ConvertToPersistentMode mode) { // pin the old block to ensure we have it loaded in memory auto handle = buffer_manager.Pin(old_block); - return ConvertToPersistent(context, block_id, std::move(old_block), std::move(handle)); + return ConvertToPersistent(context, block_id, std::move(old_block), std::move(handle), mode); } void BlockManager::UnregisterBlock(block_id_t id) { @@ -95,9 +125,7 @@ void BlockManager::UnregisterBlock(BlockHandle &block) { // in-memory buffer: buffer could have been offloaded to disk: remove the file buffer_manager.DeleteTemporaryFile(block); } else { - lock_guard lock(blocks_lock); - // on-disk block: erase from list of blocks in manager - blocks.erase(id); + UnregisterBlock(id); } } diff --git a/src/duckdb/src/storage/buffer/buffer_pool.cpp b/src/duckdb/src/storage/buffer/buffer_pool.cpp index b974dab30..2f2183b9c 100644 --- a/src/duckdb/src/storage/buffer/buffer_pool.cpp +++ b/src/duckdb/src/storage/buffer/buffer_pool.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/typedefs.hpp" #include "duckdb/parallel/concurrentqueue.hpp" #include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/storage/block_allocator.hpp" #include "duckdb/storage/temporary_memory_manager.hpp" namespace duckdb { @@ -229,13 +230,13 @@ void EvictionQueue::PurgeIteration(const idx_t purge_size) { total_dead_nodes -= actually_dequeued - alive_nodes; } -BufferPool::BufferPool(idx_t maximum_memory, bool track_eviction_timestamps, +BufferPool::BufferPool(BlockAllocator &block_allocator, idx_t maximum_memory, bool track_eviction_timestamps, 
idx_t allocator_bulk_deallocation_flush_threshold) : eviction_queue_sizes({BLOCK_AND_EXTERNAL_FILE_QUEUE_SIZE, MANAGED_BUFFER_QUEUE_SIZE, TINY_BUFFER_QUEUE_SIZE}), maximum_memory(maximum_memory), allocator_bulk_deallocation_flush_threshold(allocator_bulk_deallocation_flush_threshold), track_eviction_timestamps(track_eviction_timestamps), - temporary_memory_manager(make_uniq()) { + temporary_memory_manager(make_uniq()), block_allocator(block_allocator) { for (idx_t queue_type_idx = 0; queue_type_idx < EVICTION_QUEUE_TYPES; queue_type_idx++) { const auto types = EvictionQueueTypeIdxToFileBufferTypes(queue_type_idx); const auto &type_queue_size = eviction_queue_sizes[queue_type_idx]; @@ -333,8 +334,8 @@ BufferPool::EvictionResult BufferPool::EvictBlocksInternal(EvictionQueue &queue, bool found = false; if (memory_usage.GetUsedMemory(MemoryUsageCaches::NO_FLUSH) <= memory_limit) { - if (Allocator::SupportsFlush() && extra_memory > allocator_bulk_deallocation_flush_threshold) { - Allocator::FlushAll(); + if (extra_memory > allocator_bulk_deallocation_flush_threshold) { + block_allocator.FlushAll(extra_memory); } return {true, std::move(r)}; } @@ -362,8 +363,8 @@ BufferPool::EvictionResult BufferPool::EvictBlocksInternal(EvictionQueue &queue, if (!found) { r.Resize(0); - } else if (Allocator::SupportsFlush() && extra_memory > allocator_bulk_deallocation_flush_threshold) { - Allocator::FlushAll(); + } else if (extra_memory > allocator_bulk_deallocation_flush_threshold) { + block_allocator.FlushAll(extra_memory); } return {found, std::move(r)}; @@ -454,9 -455,7 @@ void BufferPool::SetLimit(idx_t limit, const char *exception_postscript) { "Failed to change memory limit to %lld: could not free up enough memory for the new limit%s", limit, exception_postscript); } - if (Allocator::SupportsFlush()) { - Allocator::FlushAll(); - } + block_allocator.FlushAll(); } void BufferPool::SetAllocatorBulkDeallocationFlushThreshold(idx_t threshold) { diff --git a/src/duckdb/src/storage/caching_file_system.cpp b/src/duckdb/src/storage/caching_file_system.cpp index 3de905228..1edffb7b2 100644 --- a/src/duckdb/src/storage/caching_file_system.cpp +++ b/src/duckdb/src/storage/caching_file_system.cpp @@ -41,6 +41,9 @@ CachingFileHandle::CachingFileHandle(QueryContext context, CachingFileSystem &ca const auto &open_options = path.extended_info->options; const auto validate_entry = open_options.find("validate_external_file_cache"); if (validate_entry != open_options.end()) { + if (validate_entry->second.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for validate_external_file_cache"); + } validate = BooleanValue::Get(validate_entry->second); } } @@ -79,6 +82,21 @@ FileHandle &CachingFileHandle::GetFileHandle() { return *file_handle; } +static bool ShouldExpandToFillGap(const idx_t current_length, const idx_t added_length) { + const idx_t MAX_BOUND_TO_BE_ADDED_LENGTH = 1048576; + + if (added_length > MAX_BOUND_TO_BE_ADDED_LENGTH) { + // Absolute value of what would need to be added is too high + return false; + } + if (added_length > current_length) { + // Relative value of what would need to be added is too high + return false; + } + + return true; +} + BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, const idx_t location) { BufferHandle result; if (!external_file_cache.IsEnabled()) { @@ -90,30 +108,42 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, c // Try to read from the cache, filling overlapping_ranges in the process vector> 
overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges); + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges, start_location_of_next_range); if (result.IsValid()) { return result; // Success } + idx_t new_nr_bytes = nr_bytes; + if (start_location_of_next_range.IsValid()) { + const idx_t nr_bytes_to_be_added = start_location_of_next_range.GetIndex() - location - nr_bytes; + if (ShouldExpandToFillGap(nr_bytes, nr_bytes_to_be_added)) { + // Grow the range from location to start_location_of_next_range to fill gaps in the cached ranges + new_nr_bytes = nr_bytes + nr_bytes_to_be_added; + } + } + // Finally, if we weren't able to find the file range in the cache, we have to create a new file range - result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, nr_bytes); - auto new_file_range = make_shared_ptr(result.GetBlockHandle(), nr_bytes, location, version_tag); + result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, new_nr_bytes); + auto new_file_range = + make_shared_ptr(result.GetBlockHandle(), new_nr_bytes, location, version_tag); buffer = result.Ptr(); // Interleave reading and copying from cached buffers if (OnDiskFile()) { // On-disk file: prefer interleaving reading and copying from cached buffers - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - // Remote file: prefer interleaving reading and copying from cached buffers only if reduces number of real reads - if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, false) <= 1) { - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + // Remote file: prefer interleaving reading and copying from cached buffers only if it reduces the number of + // real reads + if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, false) <= 1) { + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - GetFileHandle().Read(context, buffer, nr_bytes, location); + GetFileHandle().Read(context, buffer, new_nr_bytes, location); } } - return TryInsertFileRange(result, buffer, nr_bytes, location, new_file_range); + return TryInsertFileRange(result, buffer, new_nr_bytes, location, new_file_range); } BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { @@ -131,7 +161,12 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { // Try to read from the cache first vector> overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges); + { + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges, start_location_of_next_range); + // start_location_of_next_range is discarded in this case + } + if (result.IsValid()) { position += nr_bytes; return result; // Success @@ -213,8 +248,20 @@ const string &CachingFileHandle::GetVersionTag(const unique_ptr return cached_file.VersionTag(guard); } +idx_t CachingFileHandle::SeekPosition() { + return position; +} + +void CachingFileHandle::Seek(idx_t location) { + position = location; + if (file_handle != nullptr) { + file_handle->Seek(location); + } +} + BufferHandle 
CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges) { + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range) { BufferHandle result; // Get read lock for cached ranges @@ -246,7 +293,8 @@ BufferHandle CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_by } while (it != ranges.end()) { if (it->second->location >= this_end) { - // We're past the requested location + // We're past the requested location, we are going to bail out, save start_location_of_next_range + start_location_of_next_range = it->second->location; break; } // Check if the cached range overlaps the requested one diff --git a/src/duckdb/src/storage/caching_file_system_wrapper.cpp b/src/duckdb/src/storage/caching_file_system_wrapper.cpp new file mode 100644 index 000000000..7af44fb70 --- /dev/null +++ b/src/duckdb/src/storage/caching_file_system_wrapper.cpp @@ -0,0 +1,372 @@ +#include "duckdb/storage/caching_file_system_wrapper.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" + +namespace duckdb { + +//===----------------------------------------------------------------------===// +// CachingFileHandleWrapper implementation +//===----------------------------------------------------------------------===// +CachingFileHandleWrapper::CachingFileHandleWrapper(CachingFileSystemWrapper &file_system, + unique_ptr handle, FileOpenFlags flags) + : FileHandle(file_system, handle->GetPath(), flags), caching_handle(std::move(handle)) { + // Flags should already be validated to be read-only in OpenFileExtended +} + +CachingFileHandleWrapper::~CachingFileHandleWrapper() { +} + +void CachingFileHandleWrapper::Close() { + if (caching_handle) { + caching_handle.reset(); + } +} + +//===----------------------------------------------------------------------===// +// CachingFileSystemWrapper implementation +//===----------------------------------------------------------------------===// +CachingFileSystemWrapper::CachingFileSystemWrapper(FileSystem &file_system, DatabaseInstance &db, CachingMode mode) + : caching_file_system(file_system, db), underlying_file_system(file_system), caching_mode(mode) { +} + +CachingFileSystemWrapper CachingFileSystemWrapper::Get(ClientContext &context, CachingMode mode) { + return CachingFileSystemWrapper(FileSystem::GetFileSystem(context), *context.db, mode); +} + +bool CachingFileSystemWrapper::ShouldUseCache(const string &path) const { + if (caching_mode == CachingMode::ALWAYS_CACHE) { + return true; + } + return FileSystem::IsRemoteFile(path); +} + +CachingFileHandle *CachingFileSystemWrapper::GetCachingHandleIfPossible(FileHandle &handle) { + const auto &filepath = handle.GetPath(); + if (!ShouldUseCache(filepath)) { + return nullptr; + } + auto &wrapper = handle.Cast(); + return wrapper.caching_handle.get(); +} + +CachingFileSystemWrapper::~CachingFileSystemWrapper() { +} + +std::string CachingFileSystemWrapper::GetName() const { + return "CachingFileSystemWrapper"; +} + +//===----------------------------------------------------------------------===// +// Write Operations (Not Supported - Read-Only Filesystem) +//===----------------------------------------------------------------------===// +void CachingFileSystemWrapper::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + throw 
NotImplementedException("CachingFileSystemWrapper: Write operations are not supported. " + "CachingFileSystemWrapper is a read-only caching filesystem."); +} + +int64_t CachingFileSystemWrapper::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + throw NotImplementedException("CachingFileSystemWrapper: Write operations are not supported. " + "CachingFileSystemWrapper is a read-only caching filesystem."); +} + +bool CachingFileSystemWrapper::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) { + throw NotImplementedException("CachingFileSystemWrapper: Trim operations are not supported. " + "CachingFileSystemWrapper is a read-only caching filesystem."); +} + +void CachingFileSystemWrapper::Truncate(FileHandle &handle, int64_t new_size) { + throw NotImplementedException("CachingFileSystemWrapper: Truncate operations are not supported. " + "CachingFileSystemWrapper is a read-only caching filesystem."); +} + +void CachingFileSystemWrapper::FileSync(FileHandle &handle) { + throw NotImplementedException("CachingFileSystemWrapper: FileSync operations are not supported. " + "CachingFileSystemWrapper is a read-only caching filesystem."); +} + +//===----------------------------------------------------------------------===// +// OpenFile Operations +//===----------------------------------------------------------------------===// +unique_ptr CachingFileSystemWrapper::OpenFile(const string &path, FileOpenFlags flags, + optional_ptr opener) { + return OpenFile(OpenFileInfo(path), flags, opener); +} + +unique_ptr CachingFileSystemWrapper::OpenFile(const OpenFileInfo &path, FileOpenFlags flags, + optional_ptr opener) { + if (SupportsOpenFileExtended()) { + return OpenFileExtended(path, flags, opener); + } + throw NotImplementedException("CachingFileSystemWrapper: OpenFile is not implemented!"); +} + +unique_ptr CachingFileSystemWrapper::OpenFileExtended(const OpenFileInfo &path, FileOpenFlags flags, + optional_ptr opener) { + if (flags.OpenForWriting()) { + throw NotImplementedException("CachingFileSystemWrapper: Cannot open file for writing. " + "CachingFileSystemWrapper is a read-only caching filesystem."); + } + if (!flags.OpenForReading()) { + throw NotImplementedException("CachingFileSystemWrapper: File must be opened for reading. " + "CachingFileSystemWrapper is a read-only caching filesystem."); + } + + if (ShouldUseCache(path.path)) { + auto caching_handle = caching_file_system.OpenFile(path, flags); + return make_uniq(*this, std::move(caching_handle), flags); + } + // Bypass cache, use underlying file system directly. 
+ return underlying_file_system.OpenFile(path, flags, opener); +} + +bool CachingFileSystemWrapper::SupportsOpenFileExtended() const { + return true; +} + +//===----------------------------------------------------------------------===// +// Read Operations +//===----------------------------------------------------------------------===// +void CachingFileSystemWrapper::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.Read(handle, buffer, nr_bytes, location); + } + + data_ptr_t cached_buffer = nullptr; + auto buffer_handle = caching_handle->Read(cached_buffer, NumericCast(nr_bytes), location); + if (!buffer_handle.IsValid()) { + throw IOException("Failed to read from caching file handle: file=\"%s\", offset=%llu, bytes=%lld", + handle.GetPath().c_str(), location, nr_bytes); + } + + // Copy data from cached buffer handle to user's buffer. + memcpy(buffer, cached_buffer, NumericCast(nr_bytes)); +} + +int64_t CachingFileSystemWrapper::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + idx_t current_position = SeekPosition(handle); + Read(handle, buffer, nr_bytes, current_position); + Seek(handle, current_position + NumericCast(nr_bytes)); + return nr_bytes; +} + +//===----------------------------------------------------------------------===// +// File Metadata Operations +//===----------------------------------------------------------------------===// +int64_t CachingFileSystemWrapper::GetFileSize(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.GetFileSize(handle); + } + + return NumericCast(caching_handle->GetFileSize()); +} + +timestamp_t CachingFileSystemWrapper::GetLastModifiedTime(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.GetLastModifiedTime(handle); + } + + return caching_handle->GetLastModifiedTime(); +} + +string CachingFileSystemWrapper::GetVersionTag(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.GetVersionTag(handle); + } + + return caching_handle->GetVersionTag(); +} + +FileType CachingFileSystemWrapper::GetFileType(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.GetFileType(handle); + } + + auto &file_handle = caching_handle->GetFileHandle(); + return underlying_file_system.GetFileType(file_handle); +} + +FileMetadata CachingFileSystemWrapper::Stats(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.Stats(handle); + } + + auto &file_handle = caching_handle->GetFileHandle(); + return underlying_file_system.Stats(file_handle); +} + +//===----------------------------------------------------------------------===// +// Directory Operations +//===----------------------------------------------------------------------===// +bool CachingFileSystemWrapper::DirectoryExists(const string &directory, optional_ptr opener) { + return underlying_file_system.DirectoryExists(directory, opener); +} + +void CachingFileSystemWrapper::CreateDirectory(const string &directory, optional_ptr opener) { + underlying_file_system.CreateDirectory(directory, opener); +} + +void 
CachingFileSystemWrapper::CreateDirectoriesRecursive(const string &path, optional_ptr opener) { + underlying_file_system.CreateDirectoriesRecursive(path, opener); +} + +void CachingFileSystemWrapper::RemoveDirectory(const string &directory, optional_ptr opener) { + underlying_file_system.RemoveDirectory(directory, opener); +} + +bool CachingFileSystemWrapper::ListFiles(const string &directory, + const std::function &callback, + FileOpener *opener) { + return underlying_file_system.ListFiles(directory, callback, opener); +} + +bool CachingFileSystemWrapper::ListFilesExtended(const string &directory, + const std::function &callback, + optional_ptr opener) { + // Use the public ListFiles API which will internally call `ListFilesExtended` if supported. + return underlying_file_system.ListFiles(directory, callback, opener); +} + +bool CachingFileSystemWrapper::SupportsListFilesExtended() const { + // Cannot delegate to internal filesystem's invocaton since it's `protected`. + return true; +} + +void CachingFileSystemWrapper::MoveFile(const string &source, const string &target, optional_ptr opener) { + underlying_file_system.MoveFile(source, target, opener); +} + +bool CachingFileSystemWrapper::FileExists(const string &filename, optional_ptr opener) { + return underlying_file_system.FileExists(filename, opener); +} + +bool CachingFileSystemWrapper::IsPipe(const string &filename, optional_ptr opener) { + return underlying_file_system.IsPipe(filename, opener); +} + +void CachingFileSystemWrapper::RemoveFile(const string &filename, optional_ptr opener) { + underlying_file_system.RemoveFile(filename, opener); +} + +bool CachingFileSystemWrapper::TryRemoveFile(const string &filename, optional_ptr opener) { + return underlying_file_system.TryRemoveFile(filename, opener); +} + +//===----------------------------------------------------------------------===// +// Path Operations +//===----------------------------------------------------------------------===// +string CachingFileSystemWrapper::GetHomeDirectory() { + return underlying_file_system.GetHomeDirectory(); +} + +string CachingFileSystemWrapper::ExpandPath(const string &path) { + return underlying_file_system.ExpandPath(path); +} + +string CachingFileSystemWrapper::PathSeparator(const string &path) { + return underlying_file_system.PathSeparator(path); +} + +vector CachingFileSystemWrapper::Glob(const string &path, FileOpener *opener) { + return underlying_file_system.Glob(path, opener); +} + +//===----------------------------------------------------------------------===// +// SubSystem Operations +//===----------------------------------------------------------------------===// +void CachingFileSystemWrapper::RegisterSubSystem(unique_ptr sub_fs) { + underlying_file_system.RegisterSubSystem(std::move(sub_fs)); +} + +void CachingFileSystemWrapper::RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) { + underlying_file_system.RegisterSubSystem(compression_type, std::move(fs)); +} + +void CachingFileSystemWrapper::UnregisterSubSystem(const string &name) { + underlying_file_system.UnregisterSubSystem(name); +} + +unique_ptr CachingFileSystemWrapper::ExtractSubSystem(const string &name) { + return underlying_file_system.ExtractSubSystem(name); +} + +vector CachingFileSystemWrapper::ListSubSystems() { + return underlying_file_system.ListSubSystems(); +} + +bool CachingFileSystemWrapper::CanHandleFile(const string &fpath) { + return underlying_file_system.CanHandleFile(fpath); +} + 
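+// Descriptive note for the section below (added comment, not part of the upstream patch): Seek, SeekPosition,
+// and OnDiskFile consult the caching handle when the file goes through the cache, while the capability checks
+// (IsManuallySet, CanSeek) always delegate to the underlying file system.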
+//===----------------------------------------------------------------------===// +// File Handle Operations +//===----------------------------------------------------------------------===// +void CachingFileSystemWrapper::Seek(FileHandle &handle, idx_t location) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.Seek(handle, location); + } + + caching_handle->Seek(location); +} + +void CachingFileSystemWrapper::Reset(FileHandle &handle) { + Seek(handle, 0); +} + +idx_t CachingFileSystemWrapper::SeekPosition(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.SeekPosition(handle); + } + + return caching_handle->SeekPosition(); +} + +bool CachingFileSystemWrapper::IsManuallySet() { + return underlying_file_system.IsManuallySet(); +} + +bool CachingFileSystemWrapper::CanSeek() { + return underlying_file_system.CanSeek(); +} + +bool CachingFileSystemWrapper::OnDiskFile(FileHandle &handle) { + auto *caching_handle = GetCachingHandleIfPossible(handle); + if (!caching_handle) { + return underlying_file_system.OnDiskFile(handle); + } + + return caching_handle->OnDiskFile(); +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// +unique_ptr CachingFileSystemWrapper::OpenCompressedFile(QueryContext context, unique_ptr handle, + bool write) { + return underlying_file_system.OpenCompressedFile(context, std::move(handle), write); +} + +void CachingFileSystemWrapper::SetDisabledFileSystems(const vector &names) { + underlying_file_system.SetDisabledFileSystems(names); +} + +bool CachingFileSystemWrapper::SubSystemIsDisabled(const string &name) { + return underlying_file_system.SubSystemIsDisabled(name); +} + +bool CachingFileSystemWrapper::IsDisabledForPath(const string &path) { + return underlying_file_system.IsDisabledForPath(path); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/row_group_writer.cpp b/src/duckdb/src/storage/checkpoint/row_group_writer.cpp index 794dddca7..729ef5450 100644 --- a/src/duckdb/src/storage/checkpoint/row_group_writer.cpp +++ b/src/duckdb/src/storage/checkpoint/row_group_writer.cpp @@ -21,8 +21,8 @@ SingleFileRowGroupWriter::SingleFileRowGroupWriter(TableCatalogEntry &table, Par : RowGroupWriter(table, partial_block_manager), writer(writer), table_data_writer(table_data_writer) { } -CheckpointType SingleFileRowGroupWriter::GetCheckpointType() const { - return writer.GetCheckpointType(); +CheckpointOptions SingleFileRowGroupWriter::GetCheckpointOptions() const { + return writer.GetCheckpointOptions(); } WriteStream &SingleFileRowGroupWriter::GetPayloadWriter() { diff --git a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp index 342ea1ff5..be5b910ec 100644 --- a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp +++ b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp @@ -4,6 +4,7 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/table_statistics.hpp" @@ -50,8 +51,8 @@ unique_ptr 
SingleFileTableDataWriter::GetRowGroupWriter(RowGroup table_data_writer); } -CheckpointType SingleFileTableDataWriter::GetCheckpointType() const { - return checkpoint_manager.GetCheckpointType(); +CheckpointOptions SingleFileTableDataWriter::GetCheckpointOptions() const { + return checkpoint_manager.GetCheckpointOptions(); } MetadataManager &SingleFileTableDataWriter::GetMetadataManager() { @@ -63,6 +64,10 @@ void SingleFileTableDataWriter::WriteUnchangedTable(MetaBlockPointer pointer, id existing_rows = total_rows; } +void SingleFileTableDataWriter::FlushPartialBlocks() { + checkpoint_manager.partial_block_manager.FlushPartialBlocks(); +} + void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) { MetaBlockPointer pointer; @@ -117,17 +122,19 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat if (!v1_0_0_storage) { options.emplace("v1_0_0_storage", v1_0_0_storage); } + auto index_storage_infos = info.GetIndexes().SerializeToDisk(context, options); -#ifdef DUCKDB_BLOCK_VERIFICATION - for (auto &entry : index_storage_infos) { - for (auto &allocator : entry.allocator_infos) { - for (auto &block : allocator.block_pointers) { - checkpoint_manager.verify_block_usage_count[block.block_id]++; + auto debug_verify_blocks = DBConfig::GetSetting(GetDatabase()); + if (debug_verify_blocks) { + for (auto &entry : index_storage_infos) { + for (auto &allocator : entry.allocator_infos) { + for (auto &block : allocator.block_pointers) { + checkpoint_manager.verify_block_usage_count[block.block_id]++; + } } } } -#endif // write empty block pointers for forwards compatibility vector compat_block_pointers; diff --git a/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp b/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp index 935b068e1..48ffb70bc 100644 --- a/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp +++ b/src/duckdb/src/storage/checkpoint/write_overflow_strings_to_disk.cpp @@ -48,18 +48,6 @@ string UncompressedStringSegmentState::GetSegmentInfo() const { return "Overflow String Block Ids: " + result; } -vector UncompressedStringSegmentState::GetAdditionalBlocks() const { - return on_disk_blocks; -} - -void UncompressedStringSegmentState::Cleanup(BlockManager &manager_p) { - auto &manager = block_manager ? 
*block_manager : manager_p; - for (auto &block_id : on_disk_blocks) { - manager.MarkBlockAsModified(block_id); - } - on_disk_blocks.clear(); -} - void WriteOverflowStringsToDisk::WriteString(UncompressedStringSegmentState &state, string_t string, block_id_t &result_block, int32_t &result_offset) { auto &block_manager = partial_block_manager.GetBlockManager(); @@ -69,7 +57,7 @@ void WriteOverflowStringsToDisk::WriteString(UncompressedStringSegmentState &sta } // first write the length of the string if (block_id == INVALID_BLOCK || offset + 2 * sizeof(uint32_t) >= GetStringSpace()) { - AllocateNewBlock(state, block_manager.GetFreeBlockId()); + AllocateNewBlock(state, partial_block_manager.GetFreeBlockId()); } result_block = block_id; result_offset = UnsafeNumericCast(offset); @@ -96,7 +84,7 @@ void WriteOverflowStringsToDisk::WriteString(UncompressedStringSegmentState &sta D_ASSERT(offset == GetStringSpace()); // there is still remaining stuff to write // now write the current block to disk and allocate a new block - AllocateNewBlock(state, block_manager.GetFreeBlockId()); + AllocateNewBlock(state, partial_block_manager.GetFreeBlockId()); } } } diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp index af361d4bf..22a42d790 100644 --- a/src/duckdb/src/storage/checkpoint_manager.cpp +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -18,10 +18,10 @@ #include "duckdb/main/config.hpp" #include "duckdb/main/connection.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" #include "duckdb/parser/parsed_data/create_schema_info.hpp" #include "duckdb/parser/parsed_data/create_view_info.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/checkpoint/table_data_reader.hpp" @@ -33,16 +33,16 @@ #include "duckdb/catalog/dependency_manager.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/common/thread.hpp" namespace duckdb { void ReorderTableEntries(catalog_entry_vector_t &tables); SingleFileCheckpointWriter::SingleFileCheckpointWriter(QueryContext context, AttachedDatabase &db, - BlockManager &block_manager, CheckpointType checkpoint_type) + BlockManager &block_manager, CheckpointOptions options_p) : CheckpointWriter(db), context(context.GetClientContext()), - partial_block_manager(context, block_manager, PartialBlockType::FULL_CHECKPOINT), - checkpoint_type(checkpoint_type) { + partial_block_manager(context, block_manager, PartialBlockType::FULL_CHECKPOINT), options(options_p) { } BlockManager &SingleFileCheckpointWriter::GetBlockManager() { @@ -148,6 +148,24 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { // get the id of the first meta block auto meta_block = metadata_writer->GetMetaBlockPointer(); + // write a checkpoint flag to the WAL + // in case a crash happens during the checkpoint, we know a checkpoint was instantiated + // we write the root meta block of the planned checkpoint to the WAL + // during recovery we use this: + // * if the root meta block matches the checkpoint entry, we know the checkpoint was completed + // * if the root meta block does not match the checkpoint entry, we know the checkpoint was not completed + // if the checkpoint was completed we don't need to replay the WAL - otherwise we need to replay the WAL + // we also know if a 
checkpoint was running that we need to check for the checkpoint WAL (`.checkpoint.wal`) + // to replay any concurrent commits that have succeeded and ensure these are not lost + auto &transaction_manager = db.GetTransactionManager().Cast(); + ActiveCheckpointWrapper active_checkpoint(transaction_manager); + auto has_wal = storage_manager.WALStartCheckpoint(meta_block, options); + + auto checkpoint_sleep_ms = DBConfig::GetSetting(db.GetDatabase()); + if (checkpoint_sleep_ms > 0) { + ThreadUtil::SleepMs(checkpoint_sleep_ms); + } + vector> schemas; // we scan the set of committed schemas auto &catalog = Catalog::GetCatalog(db).Cast(); @@ -191,17 +209,6 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { metadata_writer->Flush(); table_metadata_writer->Flush(); - // write a checkpoint flag to the WAL - // this protects against the rare event that the database crashes AFTER writing the file, but BEFORE truncating the - // WAL we write an entry CHECKPOINT "meta_block_id" into the WAL upon loading, if we see there is an entry - // CHECKPOINT "meta_block_id", and the id MATCHES the head idin the file we know that the database was successfully - // checkpointed, so we know that we should avoid replaying the WAL to avoid duplicating data - bool wal_is_empty = storage_manager.GetWALSize() == 0; - if (!wal_is_empty) { - auto wal = storage_manager.GetWAL(); - wal->WriteCheckpoint(meta_block); - wal->Flush(); - } auto debug_checkpoint_abort = DBConfig::GetSetting(db.GetDatabase()); if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER) { throw FatalException("Checkpoint aborted before header write because of PRAGMA checkpoint_abort flag"); @@ -214,18 +221,21 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { header.vector_size = STANDARD_VECTOR_SIZE; block_manager.WriteHeader(context, header); -#ifdef DUCKDB_BLOCK_VERIFICATION - // extend verify_block_usage_count - auto metadata_info = storage_manager.GetMetadataInfo(); - for (auto &info : metadata_info) { - verify_block_usage_count[info.block_id]++; - } - for (auto &entry_ref : catalog_entries) { - auto &entry = entry_ref.get(); - if (entry.type == CatalogType::TABLE_ENTRY) { + auto debug_verify_blocks = DBConfig::GetSetting(db.GetDatabase()); + if (debug_verify_blocks) { + // extend verify_block_usage_count + auto metadata_info = storage_manager.GetMetadataInfo(); + for (auto &info : metadata_info) { + verify_block_usage_count[info.block_id]++; + } + for (auto &entry_ref : catalog_entries) { + auto &entry = entry_ref.get(); + if (entry.type != CatalogType::TABLE_ENTRY) { + continue; + } auto &table = entry.Cast(); auto &storage = table.GetStorage(); - auto segment_info = storage.GetColumnSegmentInfo(); + auto segment_info = storage.GetColumnSegmentInfo(context); for (auto &segment : segment_info) { verify_block_usage_count[segment.block_id]++; if (StringUtil::Contains(segment.segment_info, "Overflow String Block Ids: ")) { @@ -238,9 +248,8 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { } } } + block_manager.VerifyBlocks(verify_block_usage_count); } - block_manager.VerifyBlocks(verify_block_usage_count); -#endif if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { throw FatalException("Checkpoint aborted before truncate because of PRAGMA checkpoint_abort flag"); @@ -249,9 +258,13 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { // truncate the file block_manager.Truncate(); + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_WAL_FINISH) { + throw 
FatalException("Checkpoint aborted before truncate because of PRAGMA checkpoint_abort flag"); + } + // truncate the WAL - if (!wal_is_empty) { - storage_manager.ResetWAL(); + if (has_wal) { + storage_manager.WALFinishCheckpoint(); } } @@ -537,14 +550,21 @@ void SingleFileCheckpointWriter::WriteTable(TableCatalogEntry &table, Serializer // Write the table metadata serializer.WriteProperty(100, "table", &table); + // If there is a context available, bind indexes before serialization. + // This is necessary so that buffered index operations are replayed before we checkpoint, otherwise + // we would lose them if there was a restart after this. + if (context && context->transaction.HasActiveTransaction()) { + auto &info = table.GetStorage().GetDataTableInfo(); + info->BindIndexes(*context); + } + // FIXME: If we do not have a context, however, the unbound indexes have to be serialized to disk. + // Write the table data auto table_lock = table.GetStorage().GetCheckpointLock(); - if (auto writer = GetTableDataWriter(table)) { + auto writer = GetTableDataWriter(table); + if (writer) { writer->WriteTableData(serializer); } - // flush any partial blocks BEFORE releasing the table lock - // flushing partial blocks updates where data lives and is not thread-safe - partial_block_manager.FlushPartialBlocks(); } void CheckpointReader::ReadTable(CatalogTransaction transaction, Deserializer &deserializer) { @@ -566,7 +586,6 @@ void CheckpointReader::ReadTable(CatalogTransaction transaction, Deserializer &d void CheckpointReader::ReadTableData(CatalogTransaction transaction, Deserializer &deserializer, BoundCreateTableInfo &bound_info) { - // written in "SingleFileTableDataWriter::FinalizeTable" auto table_pointer = deserializer.ReadProperty(101, "table_pointer"); auto total_rows = deserializer.ReadProperty(102, "total_rows"); diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp index fa1ffaeba..a5f3ccc35 100644 --- a/src/duckdb/src/storage/compression/bitpacking.cpp +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -19,6 +19,7 @@ namespace duckdb { +constexpr const idx_t BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; static constexpr const idx_t BITPACKING_METADATA_GROUP_SIZE = STANDARD_VECTOR_SIZE > 512 ? 
STANDARD_VECTOR_SIZE : 2048; BitpackingMode BitpackingModeFromString(const string &str) { @@ -70,7 +71,7 @@ static bitpacking_metadata_encoded_t EncodeMeta(bitpacking_metadata_t metadata) } static bitpacking_metadata_t DecodeMeta(bitpacking_metadata_encoded_t *metadata_encoded) { bitpacking_metadata_t metadata; - metadata.mode = Load(data_ptr_cast(metadata_encoded) + 3); + metadata.mode = static_cast((*metadata_encoded >> 24) & 0xFF); metadata.offset = *metadata_encoded & 0x00FFFFFF; return metadata; } @@ -124,6 +125,9 @@ struct BitpackingState { bool all_valid; bool all_invalid; + bool has_valid; + bool has_invalid; + bool can_do_delta; bool can_do_for; @@ -139,6 +143,8 @@ struct BitpackingState { delta_offset = 0; all_valid = true; all_invalid = true; + has_valid = false; + has_invalid = false; can_do_delta = false; can_do_for = false; compression_buffer_idx = 0; @@ -299,6 +305,8 @@ struct BitpackingState { template bool Update(T value, bool is_valid) { compression_buffer_validity[compression_buffer_idx] = is_valid; + has_valid = has_valid || is_valid; + has_invalid = has_invalid || !is_valid; all_valid = all_valid && is_valid; all_invalid = all_invalid && !is_valid; @@ -341,8 +349,6 @@ unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalTyp template bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = state.Cast>(); - // We use BITPACKING_METADATA_GROUP_SIZE tuples, which can exceed the block size. // In that case, we disable bitpacking. // we are conservative here by multiplying by 2 @@ -351,6 +357,7 @@ bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { return false; } + auto &analyze_state = state.Cast>(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(count, vdata); @@ -383,7 +390,7 @@ struct BitpackingCompressionState : public CompressionState { explicit BitpackingCompressionState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info) : CompressionState(info), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_BITPACKING)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); state.data_ptr = reinterpret_cast(this); @@ -482,9 +489,18 @@ struct BitpackingCompressionState : public CompressionState { static void UpdateStats(BitpackingCompressionState *state, idx_t count) { state->current_segment->count += count; - if (WRITE_STATISTICS && !state->state.all_invalid) { - state->current_segment->stats.statistics.template UpdateNumericStats(state->state.maximum); - state->current_segment->stats.statistics.template UpdateNumericStats(state->state.minimum); + if (WRITE_STATISTICS) { + if (state->state.has_valid) { + state->current_segment->stats.statistics.SetHasNoNullFast(); + } + if (state->state.has_invalid) { + state->current_segment->stats.statistics.SetHasNullFast(); + } + + if (!state->state.all_invalid) { + state->current_segment->stats.statistics.template UpdateNumericStats(state->state.maximum); + state->current_segment->stats.statistics.template UpdateNumericStats(state->state.minimum); + } } } }; @@ -497,12 +513,12 @@ struct BitpackingCompressionState : public CompressionState { info.GetBlockSize() - BitpackingPrimitives::BITPACKING_HEADER_SIZE; } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, 
row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); auto &buffer_manager = BufferManager::GetBufferManager(db); @@ -524,9 +540,8 @@ struct BitpackingCompressionState : public CompressionState { void FlushAndCreateSegmentIfFull(idx_t required_data_bytes, idx_t required_meta_bytes) { if (!CanStore(required_data_bytes, required_meta_bytes)) { - idx_t row_start = current_segment->start + current_segment->count; FlushSegment(); - CreateEmptySegment(row_start); + CreateEmptySegment(); } } @@ -629,9 +644,9 @@ static T DeltaDecode(T *data, T previous_value, const size_t size) { template ::type> struct BitpackingScanState : public SegmentScanState { public: - explicit BitpackingScanState(ColumnSegment &segment) : current_segment(segment) { + explicit BitpackingScanState(const QueryContext &context, ColumnSegment &segment) : current_segment(segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(context, segment.block); auto data_ptr = handle.Ptr(); // load offset to bitpacking widths pointer @@ -720,7 +735,6 @@ struct BitpackingScanState : public SegmentScanState { // This skips straight to the correct metadata group idx_t meta_groups_to_skip = (skip_count + current_group_offset) / BITPACKING_METADATA_GROUP_SIZE; if (meta_groups_to_skip) { - // bitpacking_metadata_ptr points to the next metadata: this means we need to advance the pointer by n-1 bitpacking_metadata_ptr -= (meta_groups_to_skip - 1) * sizeof(bitpacking_metadata_encoded_t); LoadNextGroup(); @@ -782,8 +796,8 @@ struct BitpackingScanState : public SegmentScanState { }; template -unique_ptr BitpackingInitScan(ColumnSegment &segment) { - auto result = make_uniq>(segment); +unique_ptr BitpackingInitScan(const QueryContext &context, ColumnSegment &segment) { + auto result = make_uniq>(context, segment); return std::move(result); } @@ -892,7 +906,7 @@ void BitpackingScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_c template void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(state.context, segment); scan_state.Skip(segment, NumericCast(row_id)); D_ASSERT(scan_state.current_group_offset < BITPACKING_METADATA_GROUP_SIZE); @@ -956,10 +970,10 @@ void BitpackingSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_c // GetSegmentInfo //===--------------------------------------------------------------------===// template -InsertionOrderPreservingMap BitpackingGetSegmentInfo(ColumnSegment &segment) { +InsertionOrderPreservingMap BitpackingGetSegmentInfo(QueryContext context, ColumnSegment &segment) { map counts; auto tuple_count = segment.count.load(); - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(context, segment); for (idx_t i = 0; i < tuple_count; i += BITPACKING_METADATA_GROUP_SIZE) { if (i) { scan_state.LoadNextGroup(); diff --git a/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp index 3c84c6ec6..c363c9280 100644 --- a/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp +++ b/src/duckdb/src/storage/compression/bitpacking_hugeint.cpp @@ -119,7 +119,6 @@ static void UnpackDelta128(const 
uint32_t *__restrict in, uhugeint_t *__restrict static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t delta, uint16_t shl, uhugeint_t mask) { if (delta + shl < 32) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { @@ -127,7 +126,6 @@ static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t } } else if (delta + shl >= 32 && delta + shl < 64) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { @@ -141,7 +139,6 @@ static void PackSingle(const uhugeint_t in, uint32_t *__restrict &out, uint16_t } else if (delta + shl >= 64 && delta + shl < 96) { - if (shl == 0) { out[0] = static_cast(in & mask); } else { diff --git a/src/duckdb/src/storage/compression/dict_fsst.cpp b/src/duckdb/src/storage/compression/dict_fsst.cpp index c43567c52..9e6e2f279 100644 --- a/src/duckdb/src/storage/compression/dict_fsst.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst.cpp @@ -56,7 +56,7 @@ struct DictFSSTCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -111,12 +111,15 @@ void DictFSSTCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictFSSTCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictFSSTCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(segment, buffer_manager.Pin(segment.block)); state->Initialize(true); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); } @@ -130,7 +133,7 @@ void DictFSSTCompressionStorage::StringScanPartial(ColumnSegment &segment, Colum // clear any previously locked buffers and get the primary buffer handle auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); if (!ALLOW_DICT_VECTORS || !scan_state.AllowDictionaryScan(scan_count)) { scan_state.ScanToFlatVector(result, result_offset, start, scan_count); } else { @@ -162,7 +165,7 @@ void DictFSSTSelect(ColumnSegment &segment, ColumnScanState &state, idx_t vector auto &scan_state = state.scan_state->Cast(); if (scan_state.mode == DictFSSTMode::FSST_ONLY) { // for FSST only - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); scan_state.Select(result, start, sel, sel_count); return; } @@ -178,7 +181,7 @@ static void DictFSSTFilter(ColumnSegment &segment, ColumnScanState &state, idx_t SelectionVector &sel, idx_t &sel_count, const TableFilter &filter, TableFilterState &filter_state) { 
auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); if (scan_state.AllowDictionaryScan(vector_count)) { // only pushdown filters on dictionaries if (!scan_state.filter_result) { @@ -187,12 +190,13 @@ static void DictFSSTFilter(ColumnSegment &segment, ColumnScanState &state, idx_t scan_state.filter_result = make_unsafe_uniq_array(scan_state.dict_count); // apply the filter + auto &dict_data = scan_state.dictionary->data; UnifiedVectorFormat vdata; - scan_state.dictionary->ToUnifiedFormat(scan_state.dict_count, vdata); + dict_data.ToUnifiedFormat(scan_state.dict_count, vdata); SelectionVector dict_sel; idx_t filter_count = scan_state.dict_count; - ColumnSegment::FilterSelection(dict_sel, *scan_state.dictionary, vdata, filter, filter_state, - scan_state.dict_count, filter_count); + ColumnSegment::FilterSelection(dict_sel, dict_data, vdata, filter, filter_state, scan_state.dict_count, + filter_count); // now set all matching tuples to true for (idx_t i = 0; i < filter_count; i++) { @@ -217,8 +221,7 @@ static void DictFSSTFilter(ColumnSegment &segment, ColumnScanState &state, idx_t } sel_count = approved_tuple_count; - result.Dictionary(*(scan_state.dictionary), scan_state.dict_count, dict_sel, vector_count); - DictionaryVector::SetDictionaryId(result, to_string(CastPointerToValue(&segment))); + result.Dictionary(scan_state.dictionary, dict_sel); return; } // fallback: scan + filter diff --git a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp index 580a5cfc5..9c7a98251 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp @@ -4,6 +4,10 @@ #include "fsst.h" #include "duckdb/common/fsst.hpp" +#if defined(__MVS__) && !defined(alloca) +#define alloca __builtin_alloca +#endif + namespace duckdb { namespace dict_fsst { @@ -11,8 +15,13 @@ DictFSSTCompressionState::DictFSSTCompressionState(ColumnDataCheckpointData &che unique_ptr &&analyze_p) : CompressionState(analyze_p->info), checkpoint_data(checkpoint_data_p), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICT_FSST)), + current_string_map( + info.GetBlockManager().buffer_manager.GetBufferAllocator(), + MinValue(analyze_p.get()->total_count, info.GetBlockSize()) / 2, // maximum_size_p (amount of elements) + 1 // maximum_target_capacity_p (byte capacity) + ), analyze(std::move(analyze_p)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); } DictFSSTCompressionState::~DictFSSTCompressionState() { @@ -228,12 +237,12 @@ void DictFSSTCompressionState::FlushEncodingBuffer() { dictionary_encoding_buffer.clear(); } -void DictFSSTCompressionState::CreateEmptySegment(idx_t row_start) { +void DictFSSTCompressionState::CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, info.GetBlockSize(), - info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); // Reset the pointers into the current segment. 
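Editor's note: the hunks around here swap the dict_fsst `current_string_map` from a map-style container to what the patch's own comments call a PrimitiveDictionary, sized once from the analyze phase and used only for duplicate checks (Insert / Lookup, with Lookup returning an entry whose `IsEmpty()` and `index` are inspected). The following self-contained sketch, with hypothetical names and plain standard-library types rather than DuckDB's, illustrates that insert-only, fixed-capacity, no-rehash usage pattern; it assumes the capacity always exceeds the number of distinct strings inserted, as the analyze-phase sizing is meant to guarantee.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <vector>

class InsertOnlyStringTable {
public:
    // capacity must be > 0 and larger than the number of strings ever inserted.
    explicit InsertOnlyStringTable(size_t capacity) : slots_(capacity) {}

    // Returns the 0-based insertion index of an existing entry, or nullopt if absent.
    std::optional<uint32_t> Lookup(const std::string &s) const {
        for (size_t probe = Hash(s);; probe = (probe + 1) % slots_.size()) {
            const auto &slot = slots_[probe];
            if (!slot) {
                return std::nullopt; // empty slot reached: the string was never inserted
            }
            if (slot->first == s) {
                return slot->second;
            }
        }
    }

    // Inserts without duplicate checking; callers are expected to Lookup first.
    void Insert(const std::string &s) {
        for (size_t probe = Hash(s);; probe = (probe + 1) % slots_.size()) {
            if (!slots_[probe]) {
                slots_[probe].emplace(s, count_++);
                return;
            }
        }
    }

    uint32_t GetSize() const { return count_; }

    void Clear() {
        std::fill(slots_.begin(), slots_.end(), std::nullopt);
        count_ = 0;
    }

private:
    size_t Hash(const std::string &s) const { return std::hash<std::string> {}(s) % slots_.size(); }

    std::vector<std::optional<std::pair<std::string, uint32_t>>> slots_;
    uint32_t count_ = 0;
};

Because the table is never resized and never erased, clearing between segments (as Flush does above) is just a reset of the slots, which is what makes the up-front sizing from the analyze phase matter.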
@@ -251,7 +260,7 @@ void DictFSSTCompressionState::CreateEmptySegment(idx_t row_start) { D_ASSERT(string_lengths.empty()); string_lengths.push_back(0); dict_count = 1; - D_ASSERT(current_string_map.empty()); + D_ASSERT(current_string_map.GetSize() == 0); symbol_table_size = DConstants::INVALID_INDEX; dictionary_offset = 0; @@ -268,7 +277,6 @@ void DictFSSTCompressionState::Flush(bool final) { current_segment->count = tuple_count; - auto next_start = current_segment->start + current_segment->count; auto segment_size = Finalize(); auto &state = checkpoint_data.GetCheckpointState(); state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); @@ -280,11 +288,7 @@ void DictFSSTCompressionState::Flush(bool final) { D_ASSERT(dictionary_encoding_buffer.empty()); D_ASSERT(to_encode_string_sum == 0); - auto old_size = current_string_map.size(); - current_string_map.clear(); - if (!final) { - current_string_map.reserve(old_size); - } + current_string_map.Clear(); string_lengths.clear(); dictionary_indices.clear(); if (encoder) { @@ -296,7 +300,7 @@ void DictFSSTCompressionState::Flush(bool final) { total_tuple_count += tuple_count; if (!final) { - CreateEmptySegment(next_start); + CreateEmptySegment(); } } @@ -444,7 +448,7 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string } state.to_encode_string_sum += str_len; auto &uncompressed_string = state.dictionary_encoding_buffer.back(); - state.current_string_map[uncompressed_string] = state.dict_count; + state.current_string_map.Insert(uncompressed_string); } else { state.string_lengths.push_back(str_len); auto baseptr = @@ -452,7 +456,7 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string memcpy(baseptr + state.dictionary_offset, str.GetData(), str_len); string_t dictionary_string((const char *)(baseptr + state.dictionary_offset), str_len); // NOLINT state.dictionary_offset += str_len; - state.current_string_map[dictionary_string] = state.dict_count; + state.current_string_map.Insert(dictionary_string); } state.dict_count++; @@ -490,8 +494,8 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form if (append_state == DictionaryAppendState::ENCODED_ALL_UNIQUE || is_null) { lookup = 0; } else { - auto it = current_string_map.find(str); - lookup = it == current_string_map.end() ? DConstants::INVALID_INDEX : it->second; + auto it = current_string_map.Lookup(str); + lookup = it.IsEmpty() ? DConstants::INVALID_INDEX : it.index + 1; } switch (append_state) { @@ -785,8 +789,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { #endif // Rewrite the dictionary - current_string_map.clear(); - current_string_map.reserve(dict_count); + current_string_map.Clear(); if (new_state == DictionaryAppendState::ENCODED) { offset = 0; auto uncompressed_dictionary_ptr = dict_copy.GetData(); @@ -797,7 +800,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { auto uncompressed_str_len = string_lengths[dictionary_index]; string_t dictionary_string(uncompressed_dictionary_ptr + offset, uncompressed_str_len); - current_string_map.insert({dictionary_string, dictionary_index}); + current_string_map.Insert(dictionary_string); #ifdef DEBUG //! 
Verify that we can decompress the string @@ -822,7 +825,7 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { string_lengths[dictionary_index] = size; string_t dictionary_string((const char *)start, UnsafeNumericCast(size)); // NOLINT - current_string_map.insert({dictionary_string, dictionary_index}); + current_string_map.Insert(dictionary_string); } } dictionary_offset = new_size; @@ -863,6 +866,8 @@ void DictFSSTCompressionState::Compress(Vector &scan_vector, idx_t count) { } while (false); if (!is_null) { UncompressedStringStorage::UpdateStringStats(current_segment->stats, str); + } else { + current_segment->stats.statistics.SetHasNullFast(); } tuple_count++; } diff --git a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp index 0546096bb..f6befffe5 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp @@ -98,17 +98,18 @@ void CompressedStringScanState::Initialize(bool initialize_dictionary) { return; } - dictionary = make_buffer(segment.type, dict_count); - auto dict_child_data = FlatVector::GetData(*(dictionary)); - auto &validity = FlatVector::Validity(*dictionary); + dictionary = DictionaryVector::CreateReusableDictionary(segment.type, dict_count); + auto dict_child_data = FlatVector::GetData(dictionary->data); + auto &validity = FlatVector::Validity(dictionary->data); D_ASSERT(dict_count >= 1); validity.SetInvalid(0); + auto &dict_data = dictionary->data; uint32_t offset = 0; for (uint32_t i = 0; i < dict_count; i++) { //! We can uncompress during fetching, we need the length of the string inside the dictionary auto string_len = string_lengths[i]; - dict_child_data[i] = FetchStringFromDict(*dictionary, offset, i); + dict_child_data[i] = FetchStringFromDict(dict_data, offset, i); offset += string_len; } } @@ -158,7 +159,7 @@ void CompressedStringScanState::ScanToFlatVector(Vector &result, idx_t result_of if (dictionary) { // We have prepared the full dictionary, we can reference these strings directly - auto dictionary_values = FlatVector::GetData(*dictionary); + auto dictionary_values = FlatVector::GetData(dictionary->data); for (idx_t i = 0; i < scan_count; i++) { // Lookup dict offset in index buffer auto string_number = selvec.get_index(i + start_offset); @@ -223,8 +224,7 @@ void CompressedStringScanState::ScanToDictionaryVector(ColumnSegment &segment, V D_ASSERT(result_offset == 0); auto &selvec = GetSelVec(start, scan_count); - result.Dictionary(*(dictionary), dict_count, selvec, scan_count); - DictionaryVector::SetDictionaryId(result, to_string(CastPointerToValue(&segment))); + result.Dictionary(dictionary, selvec); result.Verify(result_offset + scan_count); } diff --git a/src/duckdb/src/storage/compression/dictionary/analyze.cpp b/src/duckdb/src/storage/compression/dictionary/analyze.cpp index 3d12bc2e1..538ad543c 100644 --- a/src/duckdb/src/storage/compression/dictionary/analyze.cpp +++ b/src/duckdb/src/storage/compression/dictionary/analyze.cpp @@ -44,10 +44,14 @@ bool DictionaryAnalyzeState::CalculateSpaceRequirements(bool new_string, idx_t s void DictionaryAnalyzeState::Flush(bool final) { segment_count++; current_tuple_count = 0; + max_unique_count_across_segments = MaxValue(max_unique_count_across_segments, current_unique_count); current_unique_count = 0; current_dict_size = 0; current_set.clear(); } +void DictionaryAnalyzeState::UpdateMaxUniqueCount() { + max_unique_count_across_segments = 
MaxValue(max_unique_count_across_segments, current_unique_count); +} void DictionaryAnalyzeState::Verify() { } diff --git a/src/duckdb/src/storage/compression/dictionary/compression.cpp b/src/duckdb/src/storage/compression/dictionary/compression.cpp index 48b02a42a..72e424878 100644 --- a/src/duckdb/src/storage/compression/dictionary/compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/compression.cpp @@ -1,25 +1,31 @@ #include "duckdb/storage/compression/dictionary/compression.hpp" -#include "duckdb/storage/segment/uncompressed.hpp" namespace duckdb { DictionaryCompressionCompressState::DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, - const CompressionInfo &info) + const CompressionInfo &info, + const idx_t max_unique_count_across_all_segments) : DictionaryCompressionState(info), checkpoint_data(checkpoint_data_p), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)), + current_string_map( + info.GetBlockManager().buffer_manager.GetBufferAllocator(), + max_unique_count_across_all_segments * 2, // * 2 results in less linear probing, improving performance + 1 // maximum_target_capacity_p, 1 because we don't care about target for our use-case, as we + // only use PrimitiveDictionary for duplicate checks, and not for writing to any target + ) { + CreateEmptySegment(); } -void DictionaryCompressionCompressState::CreateEmptySegment(idx_t row_start) { +void DictionaryCompressionCompressState::CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, info.GetBlockSize(), - info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); // Reset the buffers and the string map. - current_string_map.clear(); + current_string_map.Clear(); index_buffer.clear(); // Reserve index 0 for null strings. 
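Editor's note: the dictionary-compression hunks rely on a convention stated in the patch itself: index 0 of the selection/index buffer is reserved for NULL ("Reserve index 0 for null strings", and the assertion's "+1 is for null value"), which is why a successful Lookup is translated to `entry.index + 1`. The sketch below is a hypothetical, standalone illustration of that convention (names like DictEncode are invented, and std::unordered_map stands in for the real dedup structure): NULL rows get code 0, real dictionary entries get 1-based codes.

#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

struct DictEncoded {
    std::vector<std::string> dictionary; // dictionary[0] is implicitly the NULL entry
    std::vector<uint32_t> selection;     // per-row code: 0 = NULL, otherwise 1-based index
};

DictEncoded DictEncode(const std::vector<std::optional<std::string>> &input) {
    DictEncoded result;
    result.dictionary.emplace_back(); // reserve slot 0 for NULL
    // value -> 0-based position among the non-NULL dictionary entries
    std::unordered_map<std::string, uint32_t> seen;
    for (const auto &value : input) {
        if (!value) {
            result.selection.push_back(0); // NULL rows point at the reserved slot
            continue;
        }
        auto it = seen.find(*value);
        if (it == seen.end()) {
            it = seen.emplace(*value, static_cast<uint32_t>(seen.size())).first;
            result.dictionary.push_back(*value);
        }
        result.selection.push_back(it->second + 1); // mirrors the `entry.index + 1` in the patch
    }
    return result;
}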
@@ -42,15 +48,14 @@ void DictionaryCompressionCompressState::Verify() { D_ASSERT(DictionaryCompression::HasEnoughSpace(current_segment->count.load(), index_buffer.size(), current_dictionary.size, current_width, info.GetBlockSize())); D_ASSERT(current_dictionary.end == info.GetBlockSize()); - D_ASSERT(index_buffer.size() == current_string_map.size() + 1); // +1 is for null value + D_ASSERT(index_buffer.size() == current_string_map.GetSize() + 1); // +1 is for null value } bool DictionaryCompressionCompressState::LookupString(string_t str) { - auto search = current_string_map.find(str); - auto has_result = search != current_string_map.end(); - + const auto &entry = current_string_map.Lookup(str); + const auto has_result = !entry.IsEmpty(); if (has_result) { - latest_lookup_result = search->second; + latest_lookup_result = entry.index + 1; } return has_result; } @@ -69,11 +74,11 @@ void DictionaryCompressionCompressState::AddNewString(string_t str) { index_buffer.push_back(current_dictionary.size); selection_buffer.push_back(UnsafeNumericCast(index_buffer.size() - 1)); if (str.IsInlined()) { - current_string_map.insert({str, index_buffer.size() - 1}); + current_string_map.Insert(str); } else { string_t dictionary_string((const char *)dict_pos, UnsafeNumericCast(str.GetSize())); // NOLINT D_ASSERT(!dictionary_string.IsInlined()); - current_string_map.insert({dictionary_string, index_buffer.size() - 1}); + current_string_map.Insert(dictionary_string); } DictionaryCompression::SetDictionary(*current_segment, current_handle, current_dictionary); @@ -82,6 +87,7 @@ void DictionaryCompressionCompressState::AddNewString(string_t str) { } void DictionaryCompressionCompressState::AddNull() { + current_segment->stats.statistics.SetHasNullFast(); selection_buffer.push_back(0); current_segment->count++; } @@ -103,14 +109,12 @@ bool DictionaryCompressionCompressState::CalculateSpaceRequirements(bool new_str } void DictionaryCompressionCompressState::Flush(bool final) { - auto next_start = current_segment->start + current_segment->count; - auto segment_size = Finalize(); auto &state = checkpoint_data.GetCheckpointState(); state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); if (!final) { - CreateEmptySegment(next_start); + CreateEmptySegment(); } } diff --git a/src/duckdb/src/storage/compression/dictionary/decompression.cpp b/src/duckdb/src/storage/compression/dictionary/decompression.cpp index 51f6945e2..6e389d026 100644 --- a/src/duckdb/src/storage/compression/dictionary/decompression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/decompression.cpp @@ -48,10 +48,10 @@ void CompressedStringScanState::Initialize(ColumnSegment &segment, bool initiali return; } - dictionary = make_buffer(segment.type, index_buffer_count); + dictionary = DictionaryVector::CreateReusableDictionary(segment.type, index_buffer_count); dictionary_size = index_buffer_count; - auto dict_child_data = FlatVector::GetData(*(dictionary)); - FlatVector::SetNull(*dictionary, 0, true); + auto dict_child_data = FlatVector::GetData(dictionary->data); + FlatVector::SetNull(dictionary->data, 0, true); for (uint32_t i = 1; i < index_buffer_count; i++) { // NOTE: the passing of dict_child_vector, will not be used, its for big strings uint16_t str_len = GetStringLength(i); @@ -114,8 +114,7 @@ void CompressedStringScanState::ScanToDictionaryVector(ColumnSegment &segment, V } } - result.Dictionary(*(dictionary), dictionary_size, *sel_vec, scan_count); - DictionaryVector::SetDictionaryId(result, 
to_string(CastPointerToValue(&segment))); + result.Dictionary(dictionary, *sel_vec); } } // namespace duckdb diff --git a/src/duckdb/src/storage/compression/dictionary_compression.cpp b/src/duckdb/src/storage/compression/dictionary_compression.cpp index fa027edd9..6f10acb60 100644 --- a/src/duckdb/src/storage/compression/dictionary_compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary_compression.cpp @@ -4,8 +4,6 @@ #include "duckdb/common/bitpacking.hpp" #include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/string_map_set.hpp" #include "duckdb/common/types/vector_buffer.hpp" #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" @@ -57,7 +55,7 @@ struct DictionaryCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -89,6 +87,10 @@ idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { auto &analyze_state = state_p.Cast(); auto &state = *analyze_state.analyze_state; + if (state.current_tuple_count != 0) { + state.UpdateMaxUniqueCount(); + } + auto width = BitpackingPrimitives::MinimumBitWidth(state.current_unique_count + 1); auto req_space = DictionaryCompression::RequiredSpace(state.current_tuple_count, state.current_unique_count, state.current_dict_size, width); @@ -102,7 +104,10 @@ idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { //===--------------------------------------------------------------------===// unique_ptr DictionaryCompressionStorage::InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state) { - return make_uniq(checkpoint_data, state->info); + const auto &analyze_state = state->Cast(); + auto &actual_state = *analyze_state.analyze_state; + return make_uniq(checkpoint_data, state->info, + actual_state.max_unique_count_across_segments); } void DictionaryCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { @@ -118,7 +123,8 @@ void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictionaryCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictionaryCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(buffer_manager.Pin(segment.block)); state->Initialize(segment, true); @@ -134,7 +140,7 @@ void DictionaryCompressionStorage::StringScanPartial(ColumnSegment &segment, Col // clear any previously locked buffers and get the primary buffer handle auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); if (!ALLOW_DICT_VECTORS || scan_count != STANDARD_VECTOR_SIZE) { scan_state.ScanToFlatVector(result, result_offset, start, scan_count); } else { diff --git 
a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp index afd335dab..516bf99fe 100644 --- a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ -46,7 +46,7 @@ struct UncompressedCompressState : public CompressionState { UncompressedCompressState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info); public: - virtual void CreateEmptySegment(idx_t row_start); + virtual void CreateEmptySegment(); void FlushSegment(idx_t segment_size); void Finalize(idx_t segment_size); @@ -61,15 +61,15 @@ UncompressedCompressState::UncompressedCompressState(ColumnDataCheckpointData &c const CompressionInfo &info) : CompressionState(info), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED)) { - UncompressedCompressState::CreateEmptySegment(checkpoint_data.GetRowGroup().start); + UncompressedCompressState::CreateEmptySegment(); } -void UncompressedCompressState::CreateEmptySegment(idx_t row_start) { +void UncompressedCompressState::CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, info.GetBlockSize(), - info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); if (type.InternalType() == PhysicalType::VARCHAR) { auto &state = compressed_segment->GetSegmentState()->Cast(); auto &storage_manager = checkpoint_data.GetStorageManager(); @@ -120,12 +120,11 @@ void UncompressedFunctions::Compress(CompressionState &state_p, Vector &data, id // appended everything: finished return; } - auto next_start = state.current_segment->start + state.current_segment->count; // the segment is full: flush it to disk state.FlushSegment(state.current_segment->FinalizeAppend(state.append_state)); // now create a new segment and continue appending - state.CreateEmptySegment(next_start); + state.CreateEmptySegment(); offset += appended; count -= appended; } @@ -143,10 +142,10 @@ struct FixedSizeScanState : public SegmentScanState { BufferHandle handle; }; -unique_ptr FixedSizeInitScan(ColumnSegment &segment) { +unique_ptr FixedSizeInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); + result->handle = buffer_manager.Pin(context, segment.block); return std::move(result); } @@ -157,7 +156,7 @@ template void FixedSizeScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); auto source_data = data + start * sizeof(T); @@ -170,7 +169,7 @@ void FixedSizeScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t template void FixedSizeScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { auto &scan_state = state.scan_state->template Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); auto data = scan_state.handle.Ptr() + 
segment.GetBlockOffset(); auto source_data = data + start * sizeof(T); @@ -215,15 +214,18 @@ struct StandardFixedSizeAppend { auto target_idx = target_offset + i; bool is_null = !adata.validity.RowIsValid(source_idx); if (!is_null) { + stats.statistics.SetHasNoNullFast(); stats.statistics.UpdateNumericStats(sdata[source_idx]); tdata[target_idx] = sdata[source_idx]; } else { + stats.statistics.SetHasNullFast(); // we insert a NullValue in the null gap for debuggability // this value should never be used or read anywhere tdata[target_idx] = NullValue(); } } } else { + stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < count; i++) { auto source_idx = adata.sel->get_index(offset + i); auto target_idx = target_offset + i; diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index cbb3b3ac7..77d0d5bb4 100644 --- a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -50,7 +50,7 @@ struct FSSTStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -219,7 +219,7 @@ class FSSTCompressionState : public CompressionState { FSSTCompressionState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info) : CompressionState(info), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_FSST)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); } ~FSSTCompressionState() override { @@ -241,12 +241,12 @@ class FSSTCompressionState : public CompressionState { current_end_ptr = current_handle.Ptr() + current_dictionary.end; } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); Reset(); } @@ -276,7 +276,7 @@ class FSSTCompressionState : public CompressionState { current_segment->count++; } - void AddNull() { + void AddEmptyStringInternal() { if (!HasEnoughSpace(0)) { Flush(); if (!HasEnoughSpace(0)) { @@ -287,8 +287,13 @@ class FSSTCompressionState : public CompressionState { current_segment->count++; } + void AddNull() { + AddEmptyStringInternal(); + current_segment->stats.statistics.SetHasNullFast(); + } + void AddEmptyString() { - AddNull(); + AddEmptyStringInternal(); UncompressedStringStorage::UpdateStringStats(current_segment->stats, ""); } @@ -323,14 +328,12 @@ class FSSTCompressionState : public CompressionState { } void Flush(bool final = false) { - auto next_start = current_segment->start + current_segment->count; - auto segment_size = Finalize(); auto &state = checkpoint_data.GetCheckpointState(); state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); if (!final) { - CreateEmptySegment(next_start); + CreateEmptySegment(); } } @@ -450,7 +453,8 @@ void 
FSSTStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t auto idx = vdata.sel->get_index(i); // Note: we treat nulls and empty strings the same - if (!vdata.validity.RowIsValid(idx) || data[idx].GetSize() == 0) { + const bool is_null = !vdata.validity.RowIsValid(idx); + if (is_null || data[idx].GetSize() == 0) { continue; } @@ -569,7 +573,7 @@ struct FSSTScanState : public StringScanState { } }; -unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr FSSTStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto block_size = segment.GetBlockManager().GetBlockSize(); auto string_block_limit = StringUncompressed::GetStringBlockLimit(block_size); auto state = make_uniq(string_block_limit); @@ -585,8 +589,9 @@ unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) } state->duckdb_fsst_decoder_ptr = state->duckdb_fsst_decoder.get(); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); @@ -640,9 +645,8 @@ void FSSTStorage::EndScan(FSSTScanState &scan_state, bp_delta_offsets_t &offsets template void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { - auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); bool enable_fsst_vectors; if (ALLOW_FSST_VECTORS) { @@ -710,7 +714,7 @@ void FSSTStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx void FSSTStorage::Select(ColumnSegment &segment, ColumnScanState &state, idx_t vector_count, Vector &result, const SelectionVector &sel, idx_t sel_count) { auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); auto dict = GetDictionary(segment, scan_state.handle); @@ -734,7 +738,6 @@ void FSSTStorage::Select(ColumnSegment &segment, ColumnScanState &state, idx_t v //===--------------------------------------------------------------------===// void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto handle = buffer_manager.Pin(segment.block); auto base_ptr = handle.Ptr() + segment.GetBlockOffset(); diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp index a4d1e789b..f9cc79b47 100644 --- a/src/duckdb/src/storage/compression/numeric_constant.cpp +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -1,17 +1,19 @@ #include "duckdb/common/types/vector.hpp" #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" #include "duckdb/storage/segment/uncompressed.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/planner/filter/expression_filter.hpp" +#include 
"duckdb/planner/filter/selectivity_optional_filter.hpp" namespace duckdb { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr ConstantInitScan(ColumnSegment &segment) { +unique_ptr ConstantInitScan(const QueryContext &context, ColumnSegment &segment) { return nullptr; } @@ -105,14 +107,16 @@ void ConstantSelect(ColumnSegment &segment, ColumnScanState &state, idx_t vector //===--------------------------------------------------------------------===// // Filter //===--------------------------------------------------------------------===// -void FiltersNullValues(const LogicalType &type, const TableFilter &filter, bool &filters_nulls, - bool &filters_valid_values, TableFilterState &filter_state) { +void ConstantFun::FiltersNullValues(const LogicalType &type, const TableFilter &filter, bool &filters_nulls, + bool &filters_valid_values, TableFilterState &filter_state) { filters_nulls = false; filters_valid_values = false; switch (filter.filter_type) { - case TableFilterType::OPTIONAL_FILTER: - break; + case TableFilterType::OPTIONAL_FILTER: { + auto &opt_filter = filter.Cast(); + return opt_filter.FiltersNullValues(type, filters_nulls, filters_valid_values, filter_state); + } case TableFilterType::CONJUNCTION_OR: { auto &conjunction_or = filter.Cast(); auto &state = filter_state.Cast(); @@ -160,6 +164,11 @@ void FiltersNullValues(const LogicalType &type, const TableFilter &filter, bool filters_valid_values = false; break; } + case TableFilterType::BLOOM_FILTER: { + auto &bf = filter.Cast(); + filters_nulls = bf.FiltersNullValues(); + break; + } default: throw InternalException("FIXME: unsupported type for filter selection in validity select"); } @@ -170,7 +179,7 @@ void ConstantFilterValidity(ColumnSegment &segment, ColumnScanState &state, idx_ TableFilterState &filter_state) { // check what effect the filter has on NULL values bool filters_nulls, filters_valid_values; - FiltersNullValues(result.GetType(), filter, filters_nulls, filters_valid_values, filter_state); + ConstantFun::FiltersNullValues(result.GetType(), filter, filters_nulls, filters_valid_values, filter_state); auto &stats = segment.stats.statistics; if (stats.CanHaveNull()) { diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index 57ebaf1fa..1b4699259 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -140,18 +140,18 @@ struct RLECompressState : public CompressionState { RLECompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info) : CompressionState(info), checkpoint_data(checkpoint_data_p), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_RLE)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); state.dataptr = (void *)this; max_rle_count = MaxRLECount(); } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto column_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, info.GetBlockSize(), - info.GetBlockManager()); + auto column_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(column_segment); auto &buffer_manager = BufferManager::GetBufferManager(db); @@ -176,16 +176,20 @@ struct 
RLECompressState : public CompressionState { entry_count++; // update meta data - if (WRITE_STATISTICS && !is_null) { - current_segment->stats.statistics.UpdateNumericStats(value); + if (WRITE_STATISTICS) { + if (!is_null) { + current_segment->stats.statistics.SetHasNoNullFast(); + current_segment->stats.statistics.UpdateNumericStats(value); + } else { + current_segment->stats.statistics.SetHasNullFast(); + } } current_segment->count += count; if (entry_count == max_rle_count) { // we have finished writing this segment: flush it and create a new segment - auto row_start = current_segment->start + current_segment->count; FlushSegment(); - CreateEmptySegment(row_start); + CreateEmptySegment(); entry_count = 0; } } @@ -303,7 +307,7 @@ struct RLEScanState : public SegmentScanState { }; template -unique_ptr RLEInitScan(ColumnSegment &segment) { +unique_ptr RLEInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq>(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/roaring/analyze.cpp b/src/duckdb/src/storage/compression/roaring/analyze.cpp index 76cfa849e..5332b0333 100644 --- a/src/duckdb/src/storage/compression/roaring/analyze.cpp +++ b/src/duckdb/src/storage/compression/roaring/analyze.cpp @@ -13,7 +13,6 @@ #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/segment/uncompressed.hpp" #include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/bitpacking.hpp" namespace duckdb { @@ -167,13 +166,35 @@ void RoaringAnalyzeState::FlushContainer() { count = 0; } -void RoaringAnalyzeState::Analyze(Vector &input, idx_t count) { +template <> +void RoaringAnalyzeState::Analyze(Vector &input, idx_t count) { auto &self = *this; - RoaringStateAppender::AppendVector(self, input, count); total_count += count; } +template <> +void RoaringAnalyzeState::Analyze(Vector &input, idx_t count) { + auto &self = *this; + input.Flatten(count); + Vector bitpacked_vector(LogicalType::UBIGINT, count); + auto &bitpacked_vector_validity = FlatVector::Validity(bitpacked_vector); + bitpacked_vector_validity.EnsureWritable(); + auto dst = data_ptr_cast(bitpacked_vector_validity.GetData()); + const bool *src = FlatVector::GetData(input); + const auto &validity = FlatVector::Validity(input); + if (validity.AllValid()) { + BitPackBooleans(dst, src, count); + } else { + BitPackBooleans(dst, src, count, &validity); + } + + // Bitpack the booleans, so they can be fed through the current compression code, with the same format as a validity + // mask. 
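// Editor's note: BitPackBooleans itself is not shown in this patch; the standalone sketch below is an
// assumption about the transformation the call above performs, not the DuckDB implementation. The idea the
// surrounding comment describes is: widen nothing, just place each bool as one bit in 64-bit words laid out
// like a validity mask, so the existing roaring/validity compression path can consume BOOL columns unchanged.
// (The overload that also takes a validity mask, used for inputs containing NULLs, is omitted here.)
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<uint64_t> PackBoolsLikeValidity(const bool *src, size_t count) {
    std::vector<uint64_t> words((count + 63) / 64, 0);
    for (size_t i = 0; i < count; i++) {
        if (src[i]) {
            words[i / 64] |= uint64_t(1) << (i % 64); // bit i set <=> value i is true
        }
    }
    return words;
}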
+ RoaringStateAppender::AppendVector(self, bitpacked_vector, count); + total_count += count; +} + } // namespace roaring } // namespace duckdb diff --git a/src/duckdb/src/storage/compression/roaring/common.cpp b/src/duckdb/src/storage/compression/roaring/common.cpp index 80f7004de..10d5e3710 100644 --- a/src/duckdb/src/storage/compression/roaring/common.cpp +++ b/src/duckdb/src/storage/compression/roaring/common.cpp @@ -86,6 +86,7 @@ void SetInvalidRange(ValidityMask &result, idx_t start, idx_t end) { if (end <= start) { throw InternalException("SetInvalidRange called with end (%d) <= start (%d)", end, start); } + D_ASSERT(result.Capacity() >= end); result.EnsureWritable(); auto result_data = (validity_t *)result.GetData(); @@ -168,8 +169,8 @@ void SetInvalidRange(ValidityMask &result, idx_t start, idx_t end) { unique_ptr RoaringInitAnalyze(ColumnData &col_data, PhysicalType type) { // check if the storage version we are writing to supports roaring - auto &storage = col_data.GetStorageManager(); - if (storage.GetStorageVersion() < 4) { + const auto storage_version = col_data.GetStorageManager().GetStorageVersion(); + if (storage_version < 4 || (type == PhysicalType::BOOL && storage_version < 7)) { // compatibility mode with old versions - disable roaring return nullptr; } @@ -177,10 +178,10 @@ unique_ptr RoaringInitAnalyze(ColumnData &col_data, PhysicalType t auto state = make_uniq(info); return std::move(state); } - +template bool RoaringAnalyze(AnalyzeState &state, Vector &input, idx_t count) { auto &analyze_state = state.Cast(); - analyze_state.Analyze(input, count); + analyze_state.Analyze(input, count); return true; } @@ -198,9 +199,10 @@ unique_ptr RoaringInitCompression(ColumnDataCheckpointData &ch return make_uniq(checkpoint_data, std::move(state)); } +template void RoaringCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { auto &state = state_p.Cast(); - state.Compress(scan_vector, count); + state.Compress(scan_vector, count); } void RoaringFinalizeCompress(CompressionState &state_p) { @@ -208,7 +210,7 @@ void RoaringFinalizeCompress(CompressionState &state_p) { state.Finalize(); } -unique_ptr RoaringInitScan(ColumnSegment &segment) { +unique_ptr RoaringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } @@ -216,18 +218,55 @@ unique_ptr RoaringInitScan(ColumnSegment &segment) { //===--------------------------------------------------------------------===// // Scan base data //===--------------------------------------------------------------------===// +void ExtractValidityMaskToData(Vector &src, Vector &dst, idx_t offset, idx_t scan_count) { + // Get src's validity mask + auto &validity = FlatVector::Validity(src); + + auto write_ptr = dst.GetData() + offset; + if (validity.AllValid()) { + memset(write_ptr, 1, scan_count); // 1 is for valid + } else if (scan_count % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0) { + // "Bit-Unpack" src's validity_mask and put it in dst's data + BitpackingPrimitives::UnPackBuffer(dst.GetData() + offset, data_ptr_cast(validity.GetData()), + scan_count, 1); + } else { + // Because UnPackBuffer writes in batches of BITPACKING_ALGORITHM_GROUP_SIZE, we create a tmp_buffer first to + // prevent overflow in the case dst is smaller than the batch. 
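// Editor's note: a hypothetical, standalone illustration (not the real UnPackBuffer API) of the scan-side
// inverse used by ExtractValidityMaskToData: each validity bit is widened back to one byte (1 = valid/true).
// Since the real unpacker writes in fixed-size groups, the original code above unpacks into a scratch buffer
// rounded up to the group size and memcpy's only scan_count bytes out, so the destination is never overrun.
#include <cstddef>
#include <cstdint>

void UnpackBitsToBytes(uint8_t *dst, const uint64_t *words, size_t count) {
    for (size_t i = 0; i < count; i++) {
        dst[i] = static_cast<uint8_t>((words[i / 64] >> (i % 64)) & 1); // one byte per value
    }
}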
+ const auto tmp_buffer = + Vector(dst.GetType(), AlignValue(scan_count)); + BitpackingPrimitives::UnPackBuffer(tmp_buffer.GetData(), data_ptr_cast(validity.GetData()), scan_count, + 1); + memcpy(write_ptr, tmp_buffer.GetData(), scan_count); + } +} void RoaringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); scan_state.ScanPartial(start, result, result_offset, scan_count); } +void RoaringScanPartialBoolean(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, + idx_t result_offset) { + auto &scan_state = state.scan_state->Cast(); + auto start = state.GetPositionInSegment(); + Vector dummy(LogicalType::UBIGINT, false, false, scan_count); + scan_state.ScanPartial(start, dummy, 0, scan_count); + ExtractValidityMaskToData(dummy, result, result_offset, scan_count); +} void RoaringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { RoaringScanPartial(segment, state, scan_count, result, 0); } +void RoaringScanBoolean(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { + // Dummy vector, only created to capture the booleans in the validity mask, as the current RoaringScan populates the + // scanned data in the vector's validity mask + Vector dummy(LogicalType::UBIGINT, false, false, scan_count); + RoaringScan(segment, state, scan_count, dummy); + ExtractValidityMaskToData(dummy, result, 0, scan_count); +} + //===--------------------------------------------------------------------===// // Fetch //===--------------------------------------------------------------------===// @@ -240,6 +279,18 @@ void RoaringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_ scan_state.ScanInternal(container_state, 1, result, result_idx); } +void RoaringFetchRowBoolean(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + RoaringScanState scan_state(segment); + + idx_t internal_offset; + idx_t container_idx = scan_state.GetContainerIndex(static_cast(row_id), internal_offset); + auto &container_state = scan_state.LoadContainer(container_idx, internal_offset); + + Vector dummy(LogicalType::UBIGINT, false, false, 1); + scan_state.ScanInternal(container_state, 1, dummy, 0); + ExtractValidityMaskToData(dummy, result, result_idx, 1); +} void RoaringSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { // NO OP @@ -259,16 +310,42 @@ unique_ptr RoaringInitSegment(ColumnSegment &segment, bl // Get Function //===--------------------------------------------------------------------===// CompressionFunction GetCompressionFunction(PhysicalType data_type) { - return CompressionFunction(CompressionType::COMPRESSION_ROARING, data_type, roaring::RoaringInitAnalyze, - roaring::RoaringAnalyze, roaring::RoaringFinalAnalyze, roaring::RoaringInitCompression, - roaring::RoaringCompress, roaring::RoaringFinalizeCompress, roaring::RoaringInitScan, - roaring::RoaringScan, roaring::RoaringScanPartial, roaring::RoaringFetchRow, - roaring::RoaringSkip, roaring::RoaringInitSegment); + compression_analyze_t analyze = nullptr; + compression_compress_data_t compress = nullptr; + compression_scan_vector_t scan = nullptr; + compression_scan_partial_t scan_partial = nullptr; + compression_fetch_row_t fetch_row = nullptr; + + switch (data_type) { + case 
PhysicalType::BIT: { + analyze = roaring::RoaringAnalyze; + compress = roaring::RoaringCompress; + scan = roaring::RoaringScan; + scan_partial = roaring::RoaringScanPartial; + fetch_row = roaring::RoaringFetchRow; + break; + } + case PhysicalType::BOOL: { + analyze = roaring::RoaringAnalyze; + compress = roaring::RoaringCompress; + scan = roaring::RoaringScanBoolean; + scan_partial = roaring::RoaringScanPartialBoolean; + fetch_row = roaring::RoaringFetchRowBoolean; + break; + } + default: + throw InternalException("Roaring GetCompressionFunction, type %s not handled", EnumUtil::ToString(data_type)); + } + return CompressionFunction(CompressionType::COMPRESSION_ROARING, data_type, roaring::RoaringInitAnalyze, analyze, + roaring::RoaringFinalAnalyze, roaring::RoaringInitCompression, compress, + roaring::RoaringFinalizeCompress, roaring::RoaringInitScan, scan, scan_partial, + fetch_row, roaring::RoaringSkip, roaring::RoaringInitSegment); } CompressionFunction RoaringCompressionFun::GetFunction(PhysicalType type) { switch (type) { case PhysicalType::BIT: + case PhysicalType::BOOL: return GetCompressionFunction(type); default: throw InternalException("Unsupported type for Roaring"); @@ -278,6 +355,7 @@ CompressionFunction RoaringCompressionFun::GetFunction(PhysicalType type) { bool RoaringCompressionFun::TypeIsSupported(const PhysicalType physical_type) { switch (physical_type) { case PhysicalType::BIT: + case PhysicalType::BOOL: return true; default: return false; diff --git a/src/duckdb/src/storage/compression/roaring/compress.cpp b/src/duckdb/src/storage/compression/roaring/compress.cpp index fc2ba3625..ebe88c5ee 100644 --- a/src/duckdb/src/storage/compression/roaring/compress.cpp +++ b/src/duckdb/src/storage/compression/roaring/compress.cpp @@ -202,7 +202,7 @@ RoaringCompressState::RoaringCompressState(ColumnDataCheckpointData &checkpoint_ analyze_state(owned_analyze_state->Cast()), container_state(), container_metadata(analyze_state.container_metadata), checkpoint_data(checkpoint_data), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ROARING)) { - CreateEmptySegment(checkpoint_data.GetRowGroup().start); + CreateEmptySegment(); total_count = 0; InitializeContainer(); } @@ -212,34 +212,49 @@ idx_t RoaringCompressState::GetContainerIndex() { return index; } -idx_t RoaringCompressState::GetRemainingSpace() { - return static_cast(metadata_ptr - data_ptr); +idx_t RoaringCompressState::GetUsedDataSpace() { + return static_cast(data_ptr - (handle.Ptr() + sizeof(idx_t))); } -bool RoaringCompressState::CanStore(idx_t container_size, const ContainerMetadata &metadata) { - idx_t required_space = 0; - if (metadata.IsUncompressed()) { - // Account for the alignment we might need for this container - required_space += (AlignValue(reinterpret_cast(data_ptr))) - reinterpret_cast(data_ptr); - } - required_space += metadata.GetDataSizeInBytes(container_size); +idx_t RoaringCompressState::GetAvailableSpace() { + return static_cast(metadata_ptr - (handle.Ptr() + sizeof(idx_t))); +} +bool RoaringCompressState::CanStore(idx_t container_size_in_tuples, const ContainerMetadata &metadata) { + //! 
Required space for all the containers already stored + this additional container idx_t runs_count = metadata_collection.GetRunContainerCount(); idx_t arrays_count = metadata_collection.GetArrayAndBitsetContainerCount(); #ifdef DEBUG - idx_t current_size = metadata_collection.GetMetadataSize(runs_count + arrays_count, runs_count, arrays_count); - (void)current_size; - D_ASSERT(required_space + current_size <= GetRemainingSpace()); + { + //! Assert that whatever is already stored can actually fit on the segment + idx_t current_metadata_size = + metadata_collection.GetMetadataSize(runs_count + arrays_count, runs_count, arrays_count); + (void)current_metadata_size; + auto used_data_space = GetUsedDataSpace(); + used_data_space = AlignValue(used_data_space); + D_ASSERT(used_data_space + current_metadata_size <= GetAvailableSpace()); + } #endif + idx_t new_data_space = 0; + if (metadata.IsUncompressed()) { + //! Account for the alignment we might need for this container + //! Up to 7 bytes extra space required to align the data_ptr (see InitializeContainer) + new_data_space += (AlignValue(reinterpret_cast(data_ptr))) - reinterpret_cast(data_ptr); + } + //! Additional space required to store this new container + new_data_space += metadata.GetDataSizeInBytes(container_size_in_tuples); + if (metadata.IsRun()) { runs_count++; } else { + //! arrays_count contains both uncompressed and array container count arrays_count++; } - idx_t metadata_size = metadata_collection.GetMetadataSize(runs_count + arrays_count, runs_count, arrays_count); - required_space += metadata_size; + idx_t new_metadata_space = metadata_collection.GetMetadataSize(runs_count + arrays_count, runs_count, arrays_count); - if (required_space > GetRemainingSpace()) { + auto used_data_space = GetUsedDataSpace(); + auto required_data_space = AlignValue(used_data_space + new_data_space); + if (required_data_space + new_metadata_space > GetAvailableSpace()) { return false; } return true; @@ -257,9 +272,8 @@ void RoaringCompressState::InitializeContainer() { idx_t container_size = AlignValue( MinValue(analyze_state.total_count - container_state.appended_count, ROARING_CONTAINER_SIZE)); if (!CanStore(container_size, metadata)) { - idx_t row_start = current_segment->start + current_segment->count; FlushSegment(); - CreateEmptySegment(row_start); + CreateEmptySegment(); } // Override the pointer to write directly into the block @@ -278,12 +292,12 @@ void RoaringCompressState::InitializeContainer() { metadata_collection.AddMetadata(metadata); } -void RoaringCompressState::CreateEmptySegment(idx_t row_start) { +void RoaringCompressState::CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, info.GetBlockSize(), - info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); current_segment = std::move(compressed_segment); auto &buffer_manager = BufferManager::GetBufferManager(db); @@ -301,7 +315,7 @@ void RoaringCompressState::FlushSegment() { // +======================================+ // x: metadata_offset (to the "right" of it) - // d: data of the containers + // d: data of the containers (+ alignment) // m: metadata of the containers // This is after 'x' @@ -309,7 +323,7 @@ void RoaringCompressState::FlushSegment() { // Size of the 'd' part auto unaligned_data_size = NumericCast(data_ptr - base_ptr); 
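	// A small worked example of the alignment step below, assuming an 8-byte alignment target (both
	// validity_t and idx_t are 64-bit in DuckDB): if the container data 'd' ends 13 bytes past base_ptr,
	// the size is rounded up to the next multiple of 8, i.e. 16, and data_ptr is advanced past the 3
	// padding bytes so that the container metadata 'm' starts at an aligned offset:
	//   unaligned_data_size = 13  ->  data_size = 16  ->  data_ptr += 3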
- auto data_size = AlignValue(unaligned_data_size); + auto data_size = AlignValue(unaligned_data_size); data_ptr += data_size - unaligned_data_size; // Size of the 'm' part @@ -476,11 +490,32 @@ idx_t RoaringCompressState::Count(RoaringCompressState &state) { void RoaringCompressState::Flush(RoaringCompressState &state) { state.NextContainer(); } - -void RoaringCompressState::Compress(Vector &input, idx_t count) { +template <> +void RoaringCompressState::Compress(Vector &input, idx_t count) { auto &self = *this; RoaringStateAppender::AppendVector(self, input, count); } +template <> +void RoaringCompressState::Compress(Vector &input, idx_t count) { + auto &self = *this; + input.Flatten(count); + const bool *src = FlatVector::GetData(input); + + Vector bitpacked_vector(LogicalType::UBIGINT, count); + auto &bitpacked_vector_validity = FlatVector::Validity(bitpacked_vector); + bitpacked_vector_validity.EnsureWritable(); + const auto dst = data_ptr_cast(bitpacked_vector_validity.GetData()); + + const auto &validity = FlatVector::Validity(input); + // Bitpack the booleans, so they can be fed through the current compression code, with the same format as a validity + // mask. + if (validity.AllValid()) { + BitPackBooleans(dst, src, count, &validity, &this->current_segment->stats.statistics); + } else { + BitPackBooleans(dst, src, count, &validity, &this->current_segment->stats.statistics); + } + RoaringStateAppender::AppendVector(self, bitpacked_vector, count); +} } // namespace roaring diff --git a/src/duckdb/src/storage/compression/string_uncompressed.cpp b/src/duckdb/src/storage/compression/string_uncompressed.cpp index af3b826bf..b5b0eb931 100644 --- a/src/duckdb/src/storage/compression/string_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/string_uncompressed.cpp @@ -77,7 +77,8 @@ void UncompressedStringInitPrefetch(ColumnSegment &segment, PrefetchState &prefe } } -unique_ptr UncompressedStringStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr UncompressedStringStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); @@ -91,7 +92,7 @@ void UncompressedStringStorage::StringScanPartial(ColumnSegment &segment, Column Vector &result, idx_t result_offset) { // clear any previously locked buffers and get the primary buffer handle auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); auto dict_end = GetDictionaryEnd(segment, scan_state.handle); @@ -122,7 +123,7 @@ void UncompressedStringStorage::Select(ColumnSegment &segment, ColumnScanState & Vector &result, const SelectionVector &sel, idx_t sel_count) { // clear any previously locked buffers and get the primary buffer handle auto &scan_state = state.scan_state->Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); auto baseptr = scan_state.handle.Ptr() + segment.GetBlockOffset(); auto dict_end = GetDictionaryEnd(segment, scan_state.handle); @@ -257,10 +258,11 @@ unique_ptr UncompressedStringStorage::DeserializeState(Deser return std::move(result); } -void UncompressedStringStorage::CleanupState(ColumnSegment &segment) { +void UncompressedStringStorage::VisitBlockIds(const ColumnSegment &segment, BlockIdVisitor &visitor) { auto &state = 
segment.GetSegmentState()->Cast(); - auto &block_manager = segment.GetBlockManager(); - state.Cleanup(block_manager); + for (auto &block_id : state.on_disk_blocks) { + visitor.Visit(block_id); + } } //===--------------------------------------------------------------------===// @@ -278,7 +280,7 @@ CompressionFunction StringUncompressed::GetFunction(PhysicalType data_type) { UncompressedStringStorage::StringInitSegment, UncompressedStringStorage::StringInitAppend, UncompressedStringStorage::StringAppend, UncompressedStringStorage::FinalizeAppend, nullptr, UncompressedStringStorage::SerializeState, UncompressedStringStorage::DeserializeState, - UncompressedStringStorage::CleanupState, UncompressedStringInitPrefetch, UncompressedStringStorage::Select); + UncompressedStringStorage::VisitBlockIds, UncompressedStringInitPrefetch, UncompressedStringStorage::Select); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/compression/validity_uncompressed.cpp b/src/duckdb/src/storage/compression/validity_uncompressed.cpp index 5a71b8974..66a57e582 100644 --- a/src/duckdb/src/storage/compression/validity_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/validity_uncompressed.cpp @@ -207,7 +207,7 @@ struct ValidityScanState : public SegmentScanState { block_id_t block_id; }; -unique_ptr ValidityInitScan(ColumnSegment &segment) { +unique_ptr ValidityInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); @@ -287,6 +287,13 @@ void ValidityUncompressed::UnalignedScan(data_ptr_t input, idx_t input_size, idx // otherwise the subsequent bitwise & will modify values outside of the range of values we want to alter input_mask |= ValidityUncompressed::UPPER_MASKS[shift_amount]; + if (pos == 0) { + // We also need to set the lower bits, which are to the left of the relevant bits (x), to 1 + // These are the bits that are "behind" this scan window, and should not affect this scan + auto non_relevant_mask = ValidityUncompressed::LOWER_MASKS[result_idx]; + input_mask |= non_relevant_mask; + } + // after this, we move to the next input_entry offset = ValidityMask::BITS_PER_VALUE - input_idx; input_entry++; @@ -390,7 +397,7 @@ void ValidityUncompressed::AlignedScan(data_ptr_t input, idx_t input_start, Vect void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); static_assert(sizeof(validity_t) == sizeof(uint64_t), "validity_t should be 64-bit"); auto &scan_state = state.scan_state->Cast(); @@ -403,7 +410,7 @@ void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t s void ValidityScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { result.Flatten(scan_count); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); if (start % ValidityMask::BITS_PER_VALUE == 0) { auto &scan_state = state.scan_state->Cast(); @@ -428,7 +435,7 @@ void ValiditySelect(ColumnSegment &segment, ColumnScanState &state, idx_t, Vecto auto &result_mask = FlatVector::Validity(result); auto input_data = reinterpret_cast(buffer_ptr); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); ValidityMask 
source_mask(input_data, segment.count); for (idx_t i = 0; i < sel_count; i++) { auto source_idx = start + sel.get_index(i); @@ -504,8 +511,8 @@ idx_t ValidityFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { return ((segment.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE) * ValidityMask::STANDARD_MASK_SIZE; } -void ValidityRevertAppend(ColumnSegment &segment, idx_t start_row) { - idx_t start_bit = start_row - segment.start; +void ValidityRevertAppend(ColumnSegment &segment, idx_t new_count) { + idx_t start_bit = new_count; auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto handle = buffer_manager.Pin(segment.block); diff --git a/src/duckdb/src/storage/compression/zstd.cpp b/src/duckdb/src/storage/compression/zstd.cpp index 408855284..dc9618598 100644 --- a/src/duckdb/src/storage/compression/zstd.cpp +++ b/src/duckdb/src/storage/compression/zstd.cpp @@ -81,7 +81,7 @@ struct ZSTDStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); @@ -98,7 +98,7 @@ struct ZSTDStorage { optional_ptr segment_state); static unique_ptr SerializeState(ColumnSegment &segment); static unique_ptr DeserializeState(Deserializer &deserializer); - static void CleanupState(ColumnSegment &segment); + static void VisitBlockIds(const ColumnSegment &segment, BlockIdVisitor &visitor); }; //===--------------------------------------------------------------------===// @@ -142,6 +142,11 @@ struct ZSTDAnalyzeState : public AnalyzeState { unique_ptr ZSTDStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { // check if the storage version we are writing to supports sztd auto &storage = col_data.GetStorageManager(); + auto &block_manager = col_data.GetBlockManager(); + if (block_manager.InMemory()) { + //! 
Can't use ZSTD in in-memory environment + return nullptr; + } if (storage.GetStorageVersion() < 4) { // compatibility mode with old versions - disable zstd return nullptr; @@ -231,7 +236,6 @@ class ZSTDCompressionState : public CompressionState { checkpoint_data(checkpoint_data), partial_block_manager(checkpoint_data.GetCheckpointState().GetPartialBlockManager()), function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ZSTD)) { - total_vector_count = GetVectorCount(analyze_state->count); total_segment_count = analyze_state->segment_count; vectors_per_segment = analyze_state->vectors_per_segment; @@ -249,6 +253,7 @@ class ZSTDCompressionState : public CompressionState { public: void ResetOutBuffer() { + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); out_buffer.dst = current_buffer_ptr; out_buffer.pos = 0; @@ -307,14 +312,8 @@ class ZSTDCompressionState : public CompressionState { throw InternalException("We are asking for a new segment, but somehow we're still writing vector data onto " "the initial (segment) page"); } - idx_t row_start; - if (segment) { - row_start = segment->start + segment->count; - FlushSegment(); - } else { - row_start = checkpoint_data.GetRowGroup().start; - } - CreateEmptySegment(row_start); + FlushSegment(); + CreateEmptySegment(); // Figure out how many vectors we are storing in this segment idx_t vectors_in_segment; @@ -347,6 +346,7 @@ class ZSTDCompressionState : public CompressionState { void InitializeVector() { D_ASSERT(!in_vector); if (vector_count + 1 >= total_vector_count) { + //! Last vector vector_size = analyze_state->count - (ZSTD_VECTOR_SIZE * vector_count); } else { vector_size = ZSTD_VECTOR_SIZE; @@ -355,6 +355,7 @@ class ZSTDCompressionState : public CompressionState { current_offset = UnsafeNumericCast( AlignValue(UnsafeNumericCast(current_offset))); current_buffer_ptr = current_buffer->Ptr() + current_offset; + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); compressed_size = 0; uncompressed_size = 0; @@ -413,20 +414,16 @@ class ZSTDCompressionState : public CompressionState { throw InvalidInputException("ZSTD Compression failed: %s", duckdb_zstd::ZSTD_getErrorName(compress_result)); } + D_ASSERT(GetCurrentOffset() <= GetWritableSpace(info)); if (compress_result == 0) { // Finished break; } - if (out_buffer.pos != out_buffer.size) { - throw InternalException("Expected ZSTD_compressStream2 to fully utilize the current buffer, but pos is " - "%d, while size is %d", - out_buffer.pos, out_buffer.size); - } NewPage(); } } - void AddString(const string_t &string) { + void AddStringInternal(const string_t &string) { if (!tuple_count) { InitializeVector(); } @@ -440,7 +437,10 @@ class ZSTDCompressionState : public CompressionState { // Reached the end of this vector FlushVector(); } + } + void AddString(const string_t &string) { + AddStringInternal(string); UncompressedStringStorage::UpdateStringStats(segment->stats, string); } @@ -455,7 +455,8 @@ class ZSTDCompressionState : public CompressionState { block_id_t FinalizePage() { auto &block_manager = partial_block_manager.GetBlockManager(); - auto new_id = block_manager.GetFreeBlockId(); + auto new_id = partial_block_manager.GetFreeBlockId(); + auto &state = segment->GetSegmentState()->Cast(); state.RegisterBlock(block_manager, new_id); @@ -521,11 +522,11 @@ class ZSTDCompressionState : public CompressionState { return res; } - void CreateEmptySegment(idx_t row_start) { + void CreateEmptySegment() { auto &db = checkpoint_data.GetDatabase(); auto &type = 
checkpoint_data.GetType(); - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, function, type, row_start, - info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = + ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); segment = std::move(compressed_segment); auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); @@ -533,6 +534,9 @@ class ZSTDCompressionState : public CompressionState { } void FlushSegment() { + if (!segment) { + return; + } auto &state = checkpoint_data.GetCheckpointState(); idx_t segment_block_size; @@ -555,7 +559,8 @@ class ZSTDCompressionState : public CompressionState { } void AddNull() { - AddString(""); + segment->stats.statistics.SetHasNullFast(); + AddStringInternal(""); } public: @@ -691,7 +696,7 @@ struct ZSTDScanState : public SegmentScanState { explicit ZSTDScanState(ColumnSegment &segment) : state(segment.GetSegmentState()->Cast()), block_manager(segment.GetBlockManager()), buffer_manager(BufferManager::GetBufferManager(segment.db)), - segment_block_offset(segment.GetBlockOffset()) { + segment_block_offset(segment.GetBlockOffset()), segment(segment) { decompression_context = duckdb_zstd::ZSTD_createDCtx(); segment_handle = buffer_manager.Pin(segment.block); @@ -791,14 +796,23 @@ struct ZSTDScanState : public SegmentScanState { auto vector_size = metadata.count; + auto string_lengths_size = (sizeof(string_length_t) * vector_size); scan_state.string_lengths = reinterpret_cast(scan_state.current_buffer_ptr); - scan_state.current_buffer_ptr += (sizeof(string_length_t) * vector_size); + scan_state.current_buffer_ptr += string_lengths_size; // Update the in_buffer to point to the start of the compressed data frame idx_t current_offset = UnsafeNumericCast(scan_state.current_buffer_ptr - handle_start); scan_state.in_buffer.src = scan_state.current_buffer_ptr; scan_state.in_buffer.pos = 0; - scan_state.in_buffer.size = block_manager.GetBlockSize() - sizeof(block_id_t) - current_offset; + if (scan_state.metadata.block_offset + string_lengths_size + scan_state.metadata.compressed_size > + (segment.SegmentSize() - sizeof(block_id_t))) { + //! 
We know that the compressed size is too big to fit on the current page + scan_state.in_buffer.size = + MinValue(metadata.compressed_size, block_manager.GetBlockSize() - sizeof(block_id_t) - current_offset); + } else { + scan_state.in_buffer.size = + MinValue(metadata.compressed_size, block_manager.GetBlockSize() - current_offset); + } // Initialize the context for streaming decompression duckdb_zstd::ZSTD_DCtx_reset(decompression_context, duckdb_zstd::ZSTD_reset_session_only); @@ -832,7 +846,7 @@ struct ZSTDScanState : public SegmentScanState { scan_state.in_buffer.src = ptr; scan_state.in_buffer.pos = 0; - idx_t page_size = block_manager.GetBlockSize() - sizeof(block_id_t); + idx_t page_size = segment.SegmentSize() - sizeof(block_id_t); idx_t remaining_compressed_data = scan_state.metadata.compressed_size - scan_state.compressed_scan_count; scan_state.in_buffer.size = MinValue(page_size, remaining_compressed_data); } @@ -842,6 +856,7 @@ struct ZSTDScanState : public SegmentScanState { return; } + auto &in_buffer = scan_state.in_buffer; duckdb_zstd::ZSTD_outBuffer out_buffer; out_buffer.dst = destination; @@ -849,18 +864,25 @@ struct ZSTDScanState : public SegmentScanState { out_buffer.size = uncompressed_length; while (true) { - idx_t old_pos = scan_state.in_buffer.pos; + idx_t old_pos = in_buffer.pos; size_t res = duckdb_zstd::ZSTD_decompressStream( /* zds = */ decompression_context, /* output =*/&out_buffer, - /* input =*/&scan_state.in_buffer); - scan_state.compressed_scan_count += scan_state.in_buffer.pos - old_pos; + /* input =*/&in_buffer); + scan_state.compressed_scan_count += in_buffer.pos - old_pos; if (duckdb_zstd::ZSTD_isError(res)) { throw InvalidInputException("ZSTD Decompression failed: %s", duckdb_zstd::ZSTD_getErrorName(res)); } if (out_buffer.pos == out_buffer.size) { + //! Done decompressing the relevant portion + break; + } + if (!res) { + D_ASSERT(out_buffer.pos == out_buffer.size); + D_ASSERT(in_buffer.pos == in_buffer.size); break; } + D_ASSERT(in_buffer.pos == in_buffer.size); // Did not fully decompress, it needs a new page to read from LoadNextPageForVector(scan_state); } @@ -956,12 +978,13 @@ struct ZSTDScanState : public SegmentScanState { idx_t segment_count; //! The amount of tuples consumed idx_t scanned_count = 0; + ColumnSegment &segment; //! 
Buffer for skipping data AllocatedData skip_buffer; }; -unique_ptr ZSTDStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr ZSTDStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } @@ -972,7 +995,7 @@ unique_ptr ZSTDStorage::StringInitScan(ColumnSegment &segment) void ZSTDStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset) { auto &scan_state = state.scan_state->template Cast(); - auto start = segment.GetRelativeIndex(state.row_index); + auto start = state.GetPositionInSegment(); scan_state.ScanPartial(start, result, result_offset, scan_count); } @@ -1019,11 +1042,10 @@ unique_ptr ZSTDStorage::DeserializeState(Deserializer &deser return std::move(result); } -void ZSTDStorage::CleanupState(ColumnSegment &segment) { +void ZSTDStorage::VisitBlockIds(const ColumnSegment &segment, BlockIdVisitor &visitor) { auto &state = segment.GetSegmentState()->Cast(); - auto &block_manager = segment.GetBlockManager(); for (auto &block_id : state.on_disk_blocks) { - block_manager.MarkBlockAsModified(block_id); + visitor.Visit(block_id); } } @@ -1040,7 +1062,7 @@ CompressionFunction ZSTDFun::GetFunction(PhysicalType data_type) { zstd.init_segment = ZSTDStorage::StringInitSegment; zstd.serialize_state = ZSTDStorage::SerializeState; zstd.deserialize_state = ZSTDStorage::DeserializeState; - zstd.cleanup_state = ZSTDStorage::CleanupState; + zstd.visit_block_ids = ZSTDStorage::VisitBlockIds; return zstd; } diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 7d19449bb..2af933bc8 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -30,6 +30,7 @@ #include "duckdb/storage/table/update_state.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" namespace duckdb { @@ -57,12 +58,7 @@ DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_m this->row_groups = make_shared_ptr(info, io_manager, types, 0); if (data && data->row_group_count > 0) { this->row_groups->Initialize(*data); - if (!HasIndexes()) { - // if we don't have indexes, always append a new row group upon appending - // we can clean up this row group again when vacuuming - // since we don't yet support vacuum when there are indexes, we only do this when there are no indexes - row_groups->SetAppendRequiresNewRowGroup(); - } + row_groups->SetAppendRequiresNewRowGroup(); } else { this->row_groups->InitializeEmpty(); D_ASSERT(row_groups->GetTotalRows() == 0); @@ -146,7 +142,6 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_co DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint &constraint) : db(parent.db), info(parent.info), row_groups(parent.row_groups), version(DataTableVersion::MAIN_TABLE) { - // ALTER COLUMN to add a new constraint. // Clone the storage info vector or the table. 
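Both UncompressedStringStorage and ZSTDStorage above replace their CleanupState hook with VisitBlockIds, which reports every on-disk block id owned by a segment to a caller-supplied visitor instead of marking the blocks as modified itself. The BlockIdVisitor declaration is not part of this patch, so the sketch below assumes a minimal single-method interface plus a hypothetical collecting visitor, purely to illustrate how such a hook can be consumed:

#include <cstdint>
#include <vector>

using block_id_t = int64_t;

// Assumed shape of the visitor the compression callbacks now receive.
struct BlockIdVisitor {
	virtual ~BlockIdVisitor() = default;
	virtual void Visit(block_id_t block_id) = 0;
};

// Hypothetical visitor: collect the block ids owned by a segment so the caller
// decides what happens to them (mark as modified, count for vacuum, verify, ...).
struct CollectBlockIds final : BlockIdVisitor {
	std::vector<block_id_t> block_ids;
	void Visit(block_id_t block_id) override {
		block_ids.push_back(block_id);
	}
};

With the block-id enumeration separated from the action taken per block, the same compression code can serve segment cleanup, vacuum accounting, or verification without knowing which caller is asking.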
@@ -173,7 +168,6 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) : db(parent.db), info(parent.info), version(DataTableVersion::MAIN_TABLE) { - auto &local_storage = LocalStorage::Get(context, db); // prevent any tuples from being added to the parent lock_guard lock(append_lock); @@ -242,18 +236,16 @@ TableIOManager &TableIOManager::Get(DataTable &table) { //===--------------------------------------------------------------------===// void DataTable::InitializeScan(ClientContext &context, DuckTransaction &transaction, TableScanState &state, const vector &column_ids, optional_ptr table_filters) { - state.checkpoint_lock = transaction.SharedLockTable(*info); auto &local_storage = LocalStorage::Get(transaction); state.Initialize(column_ids, context, table_filters); - row_groups->InitializeScan(state.table_state, column_ids, table_filters); + row_groups->InitializeScan(context, state.table_state, column_ids, table_filters); local_storage.InitializeScan(*this, state.local_state, table_filters); } void DataTable::InitializeScanWithOffset(DuckTransaction &transaction, TableScanState &state, const vector &column_ids, idx_t start_row, idx_t end_row) { - state.checkpoint_lock = transaction.SharedLockTable(*info); state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(state.table_state, column_ids, start_row, end_row); + row_groups->InitializeScanWithOffset(QueryContext(), state.table_state, column_ids, start_row, end_row); } idx_t DataTable::GetRowGroupSize() const { @@ -280,8 +272,6 @@ idx_t DataTable::MaxThreads(ClientContext &context) const { void DataTable::InitializeParallelScan(ClientContext &context, ParallelTableScanState &state) { auto &local_storage = LocalStorage::Get(context, db); - auto &transaction = DuckTransaction::Get(context, db); - state.checkpoint_lock = transaction.SharedLockTable(*info); row_groups->InitializeParallelScan(state.scan_state); local_storage.InitializeParallelScan(*this, state.local_state); @@ -369,10 +359,18 @@ void DataTable::VacuumIndexes() { } void DataTable::VerifyIndexBuffers() { - info->indexes.Scan([&](Index &index) { + info->VerifyIndexBuffers(); +} + +void DataTableInfo::VerifyIndexBuffers() { + indexes.ScanEntries([&](IndexEntry &entry) { + auto &index = *entry.index; if (index.IsBound()) { index.Cast().VerifyBuffers(); } + if (entry.deleted_rows_in_use) { + entry.deleted_rows_in_use->VerifyBuffers(); + } return false; }); } @@ -427,12 +425,10 @@ TableStorageInfo DataTable::GetStorageInfo() { //===--------------------------------------------------------------------===// void DataTable::Fetch(DuckTransaction &transaction, DataChunk &result, const vector &column_ids, const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state) { - auto lock = transaction.SharedLockTable(*info); row_groups->Fetch(transaction, result, column_ids, row_identifiers, fetch_count, state); } bool DataTable::CanFetch(DuckTransaction &transaction, const row_t row_id) { - auto lock = transaction.SharedLockTable(*info); return row_groups->CanFetch(transaction, row_id); } @@ -681,7 +677,7 @@ void DataTable::VerifyNewConstraint(LocalStorage &local_storage, DataTable &pare throw NotImplementedException("FIXME: ALTER COLUMN with such constraint is not supported yet"); } - parent.row_groups->VerifyNewConstraint(parent, constraint); + 
parent.row_groups->VerifyNewConstraint(local_storage.GetClientContext(), parent, constraint); local_storage.VerifyNewConstraint(parent, constraint); } @@ -768,7 +764,6 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, optional_ptr storage, optional_ptr manager) { - auto &table = constraint_state.table; if (table.HasGeneratedColumns()) { // Verify the generated columns against the inserted values. @@ -958,7 +953,6 @@ void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, Da void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection, const vector> &bound_constraints, optional_ptr> column_ids) { - LocalAppendState append_state; auto &storage = table.GetStorage(); storage.InitializeLocalAppend(append_state, table, context, bound_constraints); @@ -1014,15 +1008,38 @@ void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, Co storage.FinalizeLocalAppend(append_state); } -void DataTable::AppendLock(TableAppendState &state) { +void DataTable::AppendLock(DuckTransaction &transaction, TableAppendState &state) { state.append_lock = unique_lock(append_lock); if (!IsMainTable()) { throw TransactionException("Transaction conflict: attempting to insert into table \"%s\" but it has been %s by " "a different transaction", GetTableName(), TableModification()); } + state.table_lock = transaction.SharedLockTable(*info); state.row_start = NumericCast(row_groups->GetTotalRows()); state.current_row = state.row_start; + auto &transaction_manager = transaction.GetTransactionManager(); + auto active_checkpoint = transaction_manager.GetActiveCheckpoint(); + if (info->IsUnseenCheckpoint(active_checkpoint)) { + // there is a checkpoint active while we are appending + // in this case we cannot just blindly append to the last row group, because we need to checkpoint that + // always start a new row group in this case + row_groups->SetAppendRequiresNewRowGroup(); + } +} + +bool DataTableInfo::IsUnseenCheckpoint(transaction_t checkpoint_id) { + if (checkpoint_id == MAX_TRANSACTION_ID) { + // no active checkpoint + return false; + } + if (last_seen_checkpoint.IsValid() && last_seen_checkpoint.GetIndex() == checkpoint_id) { + // we have already seen this checkpoint + return false; + } + // we have not yet seen this checkpoint + last_seen_checkpoint = checkpoint_id; + return true; } void DataTable::InitializeAppend(DuckTransaction &transaction, TableAppendState &state) { @@ -1062,7 +1079,8 @@ void DataTable::ScanTableSegment(DuckTransaction &transaction, idx_t row_start, CreateIndexScanState state; InitializeScanWithOffset(transaction, state, column_ids, row_start, row_start + count); - auto row_start_aligned = state.table_state.row_group->start + state.table_state.vector_index * STANDARD_VECTOR_SIZE; + auto row_start_aligned = + state.table_state.row_group->GetRowStart() + state.table_state.vector_index * STANDARD_VECTOR_SIZE; idx_t current_row = row_start_aligned; while (current_row < end) { @@ -1141,6 +1159,7 @@ void DataTable::RevertAppendInternal(idx_t start_row) { void DataTable::RevertAppend(DuckTransaction &transaction, idx_t start_row, idx_t count) { lock_guard lock(append_lock); + auto table_lock = transaction.SharedLockTable(*info); // revert any appends to indexes if (!info->indexes.Empty()) { @@ -1197,7 +1216,7 @@ ErrorData DataTable::AppendToIndexes(TableIndexList &indexes, optional_ptr(); - unbound_index.BufferChunk(index_chunk, row_ids, mapped_column_ids); + unbound_index.BufferChunk(index_chunk, 
row_ids, mapped_column_ids, BufferedIndexReplay::INSERT_ENTRY); return false; } @@ -1245,7 +1264,7 @@ ErrorData DataTable::AppendToIndexes(optional_ptr delete_indexes index_append_mode); } -void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row_t row_start) { +void DataTable::RevertIndexAppend(TableAppendState &state, DataChunk &chunk, row_t row_start) { D_ASSERT(IsMainTable()); if (info->indexes.Empty()) { return; @@ -1255,24 +1274,22 @@ void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, row VectorOperations::GenerateSequence(row_identifiers, chunk.size(), row_start, 1); // now remove the entries from the indices - RemoveFromIndexes(state, chunk, row_identifiers); + RevertIndexAppend(state, chunk, row_identifiers); } -void DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers) { +void DataTable::RevertIndexAppend(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers) { D_ASSERT(IsMainTable()); info->indexes.Scan([&](Index &index) { - if (!index.IsBound()) { - throw InternalException("Unbound index found in DataTable::RemoveFromIndexes"); - } - auto &bound_index = index.Cast(); - bound_index.Delete(chunk, row_identifiers); + auto &main_index = index.Cast(); + main_index.Delete(chunk, row_identifiers); return false; }); } -void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { +void DataTable::RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count, + IndexRemovalType removal_type) { D_ASSERT(IsMainTable()); - row_groups->RemoveFromIndexes(info->indexes, row_identifiers, count); + row_groups->RemoveFromIndexes(context, info->indexes, row_identifiers, count, removal_type); } //===--------------------------------------------------------------------===// @@ -1544,7 +1561,7 @@ void DataTable::Update(TableUpdateState &state, ClientContext &context, Vector & row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); row_ids_slice.Flatten(n_global_update); - row_groups->Update(transaction, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); + row_groups->Update(transaction, *this, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); } } @@ -1568,7 +1585,7 @@ void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, V updates.Flatten(); row_ids.Flatten(updates.size()); - row_groups->UpdateColumn(transaction, row_ids, column_path, updates); + row_groups->UpdateColumn(transaction, *this, row_ids, column_path, updates); } //===--------------------------------------------------------------------===// @@ -1593,10 +1610,6 @@ unique_ptr DataTable::GetSample() { //===--------------------------------------------------------------------===// // Checkpoint //===--------------------------------------------------------------------===// -unique_ptr DataTable::GetSharedCheckpointLock() { - return info->checkpoint_lock.GetSharedLock(); -} - unique_ptr DataTable::GetCheckpointLock() { return info->checkpoint_lock.GetExclusiveLock(); } @@ -1604,7 +1617,6 @@ unique_ptr DataTable::GetCheckpointLock() { void DataTable::Checkpoint(TableDataWriter &writer, Serializer &serializer) { // checkpoint each individual row group TableStatistics global_stats; - row_groups->CopyStats(global_stats); row_groups->Checkpoint(writer, global_stats); if (!HasIndexes()) { row_groups->SetAppendRequiresNewRowGroup(); @@ -1616,6 +1628,9 @@ void DataTable::Checkpoint(TableDataWriter &writer, Serializer &serializer) { // table pointer // 
index data writer.FinalizeTable(global_stats, *info, *row_groups, serializer); + if (writer.CanOverrideBaseStats()) { + row_groups->SetStats(global_stats); + } } void DataTable::CommitDropColumn(const idx_t column_index) { @@ -1649,9 +1664,8 @@ void DataTable::CommitDropTable() { //===--------------------------------------------------------------------===// // Column Segment Info //===--------------------------------------------------------------------===// -vector DataTable::GetColumnSegmentInfo() { - auto lock = GetSharedCheckpointLock(); - return row_groups->GetColumnSegmentInfo(); +vector DataTable::GetColumnSegmentInfo(const QueryContext &context) { + return row_groups->GetColumnSegmentInfo(context); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/external_file_cache.cpp b/src/duckdb/src/storage/external_file_cache.cpp index bcd5730f0..304116c7f 100644 --- a/src/duckdb/src/storage/external_file_cache.cpp +++ b/src/duckdb/src/storage/external_file_cache.cpp @@ -57,7 +57,8 @@ void ExternalFileCache::CachedFileRange::VerifyCheckSum() { #endif } -ExternalFileCache::CachedFile::CachedFile(string path_p) : path(std::move(path_p)) { +ExternalFileCache::CachedFile::CachedFile(string path_p) + : path(std::move(path_p)), file_size(0), last_modified(0), can_seek(false), on_disk_file(false) { } void ExternalFileCache::CachedFile::Verify(const unique_ptr &guard) const { diff --git a/src/duckdb/src/storage/index.cpp b/src/duckdb/src/storage/index.cpp index ca136d631..ab5c5b6ed 100644 --- a/src/duckdb/src/storage/index.cpp +++ b/src/duckdb/src/storage/index.cpp @@ -7,10 +7,6 @@ namespace duckdb { Index::Index(const vector &column_ids, TableIOManager &table_io_manager, AttachedDatabase &db) : column_ids(column_ids), table_io_manager(table_io_manager), db(db) { - - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("indexes are not supported on big endian architectures"); - } // create the column id set column_id_set.insert(column_ids.begin(), column_ids.end()); } diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index e3cbb8f3b..e8dba4937 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -16,12 +16,11 @@ namespace duckdb { LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) - : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(context, table), - merged_storage(false) { - + : context(context), table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), + optimistic_writer(context, table), merged_storage(false) { auto types = table.GetTypes(); auto data_table_info = table.GetDataTableInfo(); - row_groups = OptimisticDataWriter::CreateCollection(table, types); + row_groups = optimistic_writer.CreateCollection(table, types, OptimisticWritePartialManagers::GLOBAL); auto &collection = *row_groups->collection; collection.InitializeEmpty(); @@ -63,10 +62,9 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_data_table, LocalTableStorage &parent, const idx_t alter_column_index, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) - : table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), deleted_rows(parent.deleted_rows), - optimistic_collections(std::move(parent.optimistic_collections)), + : 
context(context), table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), + deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_data_table, parent.optimistic_writer), merged_storage(parent.merged_storage) { - // Alter the column type. auto &parent_collection = *parent.row_groups->collection; auto new_collection = @@ -83,7 +81,6 @@ LocalTableStorage::LocalTableStorage(DataTable &new_data_table, LocalTableStorag : table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_data_table, parent.optimistic_writer), merged_storage(parent.merged_storage) { - // Remove the column from the previous table storage. auto &parent_collection = *parent.row_groups->collection; auto new_collection = parent_collection.RemoveColumn(drop_column_index); @@ -99,7 +96,6 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_dt, : table_ref(new_dt), allocator(Allocator::Get(new_dt.db)), deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_dt, parent.optimistic_writer), merged_storage(parent.merged_storage) { - auto &parent_collection = *parent.row_groups->collection; auto new_collection = parent_collection.AddColumn(context, new_column, default_executor); row_groups = std::move(parent.row_groups); @@ -115,7 +111,7 @@ void LocalTableStorage::InitializeScan(CollectionScanState &state, optional_ptr< if (collection.GetTotalRows() == 0) { throw InternalException("No rows in LocalTableStorage row group for scan"); } - collection.InitializeScan(state, state.GetColumnIds(), table_filters.get()); + collection.InitializeScan(context, state, state.GetColumnIds(), table_filters.get()); } idx_t LocalTableStorage::EstimatedSize() { @@ -164,12 +160,25 @@ void LocalTableStorage::FlushBlocks() { ErrorData LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGroupCollection &source, TableIndexList &index_list, const vector &table_types, row_t &start_row) { - // In this function, we only care about scanning the indexed columns of a table. + // mapped_column_ids contains the physical column indices of each Indexed column in the table. + // This mapping is used to retrieve the physical column index for the corresponding vector of an index chunk scan. + // For example, if we are processing data for index_chunk.data[i], we can retrieve the physical column index + // by getting the value at mapped_column_ids[i]. + // An important note is that the index_chunk orderings are created in accordance with this mapping, not the other + // way around. (Check the scan code below, where the mapped_column_ids is passed as a parameter to the scan. + // The index_chunk inside of that lambda is ordered according to the mapping that is a parameter to the scan). + + // mapped_column_ids is used in two places: + // 1) To create the physical table chunk in this function. + // 2) If we are in an unbound state (i.e., WAL replay is happening right now), this mapping and the index_chunk + // are buffered in unbound_index. However, there can also be buffered deletes happening, so it is important + // to maintain a canonical representation of the mapping, which is just sorting. 
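	// A concrete illustration of the mapping described above (the column numbers are invented):
	// if the indexed physical columns are {4, 1, 6}, sorting gives mapped_column_ids = {1, 4, 6},
	// so the scan below fills index_chunk.data[0] from column 1, data[1] from column 4 and
	// data[2] from column 6; entry i of the index chunk therefore always corresponds to physical
	// column mapped_column_ids[i]. Because buffered inserts and buffered deletes both go through
	// this sorted ordering, an unbound index replays them against one canonical column layout.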
auto indexed_columns = index_list.GetRequiredColumns(); vector mapped_column_ids; for (auto &col : indexed_columns) { mapped_column_ids.emplace_back(col); } + std::sort(mapped_column_ids.begin(), mapped_column_ids.end()); // However, because the bound expressions of the indexes (and their bound // column references) are in relation to ALL table columns, we create an @@ -178,6 +187,7 @@ ErrorData LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGr DataChunk table_chunk; table_chunk.InitializeEmpty(table_types); + // index_chunk scans are created here in the mapped_column_ids ordering (see note above). ErrorData error; source.Scan(transaction, mapped_column_ids, [&](DataChunk &index_chunk) -> bool { D_ASSERT(index_chunk.ColumnCount() == mapped_column_ids.size()); @@ -205,7 +215,6 @@ void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppen bool append_to_table) { // In this function, we might scan all table columns, // as we might also append to the table itself (append_to_table). - auto &table = table_ref.get(); if (append_to_table) { table.InitializeAppend(transaction, append_state); @@ -259,7 +268,7 @@ void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppen collection.Scan(transaction, [&](DataChunk &chunk) -> bool { // Remove the chunk. try { - table.RemoveFromIndexes(append_state, chunk, current_row); + table.RevertIndexAppend(append_state, chunk, current_row); } catch (std::exception &ex) { // LCOV_EXCL_START error = ErrorData(ex); return false; @@ -564,7 +573,8 @@ idx_t LocalStorage::Delete(DataTable &table, Vector &row_ids, idx_t count) { // delete from unique indices (if any) if (!storage->append_indexes.Empty()) { - storage->GetCollection().RemoveFromIndexes(storage->append_indexes, row_ids, count); + storage->GetCollection().RemoveFromIndexes(context, storage->append_indexes, row_ids, count, + IndexRemovalType::MAIN_INDEX_ONLY); } auto ids = FlatVector::GetData(row_ids); @@ -580,7 +590,7 @@ void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector(row_ids); - storage->GetCollection().Update(TransactionData(0, 0), ids, column_ids, updates); + storage->GetCollection().Update(TransactionData(0, 0), table, ids, column_ids, updates); } void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ptr commit_state) { @@ -598,7 +608,7 @@ void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ const auto row_group_size = storage.GetCollection().GetRowGroupSize(); TableAppendState append_state; - table.AppendLock(append_state); + table.AppendLock(transaction, append_state); transaction.PushAppend(table, NumericCast(append_state.row_start), append_count); if ((append_state.row_start == 0 || storage.GetCollection().GetTotalRows() >= row_group_size) && storage.deleted_rows == 0) { @@ -752,7 +762,7 @@ void LocalStorage::VerifyNewConstraint(DataTable &parent, const BoundConstraint if (!storage) { return; } - storage->GetCollection().VerifyNewConstraint(parent, constraint); + storage->GetCollection().VerifyNewConstraint(context, parent, constraint); } } // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp index 8674f742d..91db5f75e 100644 --- a/src/duckdb/src/storage/metadata/metadata_manager.cpp +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -99,12 +99,16 @@ MetadataHandle MetadataManager::Pin(const MetadataPointer &pointer) { return Pin(QueryContext(), pointer); } 
-MetadataHandle MetadataManager::Pin(QueryContext context, const MetadataPointer &pointer) { +MetadataHandle MetadataManager::Pin(const QueryContext &context, const MetadataPointer &pointer) { D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); shared_ptr block_handle; { lock_guard guard(block_lock); - auto &block = blocks[UnsafeNumericCast(pointer.block_index)]; + auto entry = blocks.find(UnsafeNumericCast(pointer.block_index)); + if (entry == blocks.end()) { + throw InternalException("Trying to pin block %llu - but the block did not exist", pointer.block_index); + } + auto &block = entry->second; #ifdef DEBUG for (auto &free_block : block.free_blocks) { if (free_block == pointer.index) { @@ -272,15 +276,18 @@ void MetadataManager::Flush() { } continue; } - auto handle = buffer_manager.Pin(block.block); + auto block_handle = block.block; + auto handle = buffer_manager.Pin(block_handle); // zero-initialize the few leftover bytes memset(handle.Ptr() + total_metadata_size, 0, block_manager.GetBlockSize() - total_metadata_size); D_ASSERT(kv.first == block.block_id); - if (block.block->BlockId() >= MAXIMUM_BLOCK) { - auto new_block = - block_manager.ConvertToPersistent(QueryContext(), kv.first, block.block, std::move(handle)); - + if (block_handle->BlockId() >= MAXIMUM_BLOCK) { // Convert the temporary block to a persistent block. + // we cannot use ConvertToPersistent as another thread might still be reading the block + // so we use the safe version of ConvertToPersistent + auto new_block = block_manager.ConvertToPersistent(QueryContext(), kv.first, std::move(block_handle), + std::move(handle), ConvertToPersistentMode::THREAD_SAFE); + guard.lock(); block.block = std::move(new_block); guard.unlock(); @@ -366,6 +373,7 @@ void MetadataBlock::FreeBlocksFromInteger(idx_t free_list) { } void MetadataManager::MarkBlocksAsModified() { + unique_lock guard(block_lock); // for any blocks that were modified in the last checkpoint - set them to free blocks currently for (auto &kv : modified_blocks) { auto block_id = kv.first; @@ -379,7 +387,10 @@ void MetadataManager::MarkBlocksAsModified() { if (new_free_blocks == NumericLimits::Maximum()) { // if new free_blocks is all blocks - mark entire block as modified blocks.erase(entry); + + guard.unlock(); block_manager.MarkBlockAsModified(block_id); + guard.lock(); } else { // set the new set of free blocks block.FreeBlocksFromInteger(new_free_blocks); @@ -414,6 +425,18 @@ void MetadataManager::ClearModifiedBlocks(const vector &pointe } } +bool MetadataManager::BlockHasBeenCleared(const MetaBlockPointer &pointer) { + unique_lock guard(block_lock); + auto block_id = pointer.GetBlockId(); + auto block_index = pointer.GetBlockIndex(); + auto entry = modified_blocks.find(block_id); + if (entry == modified_blocks.end()) { + throw InternalException("BlockHasBeenCleared - Block id %llu not found in modified_blocks", block_id); + } + auto &modified_list = entry->second; + return (modified_list & (1ULL << block_index)) == 0ULL; +} + vector MetadataManager::GetMetadataInfo() const { vector result; unique_lock guard(block_lock); @@ -446,7 +469,7 @@ block_id_t MetadataManager::PeekNextBlockId() const { } block_id_t MetadataManager::GetNextBlockId() const { - return block_manager.GetFreeBlockId(); + return block_manager.GetFreeBlockIdForCheckpoint(); } } // namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp index 06c2b1c1b..342833448 100644 --- 
a/src/duckdb/src/storage/metadata/metadata_reader.cpp +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -4,11 +4,8 @@ namespace duckdb { MetadataReader::MetadataReader(MetadataManager &manager, MetaBlockPointer pointer, optional_ptr> read_pointers_p, BlockReaderType type) - : manager(manager), type(type), next_pointer(FromDiskPointer(pointer)), has_next_block(true), - read_pointers(read_pointers_p), index(0), offset(0), next_offset(pointer.offset), capacity(0) { - if (read_pointers) { - read_pointers->push_back(pointer); - } + : manager(manager), type(type), next_pointer(pointer), has_next_block(true), read_pointers(read_pointers_p), + index(0), offset(0), next_offset(pointer.offset), capacity(0) { } MetadataReader::MetadataReader(MetadataManager &manager, BlockPointer pointer) @@ -59,11 +56,10 @@ MetaBlockPointer MetadataReader::GetMetaBlockPointer() { vector MetadataReader::GetRemainingBlocks(MetaBlockPointer last_block) { vector result; while (has_next_block) { - auto next_block_pointer = manager.GetDiskPointer(next_pointer, UnsafeNumericCast(next_offset)); - if (last_block.IsValid() && next_block_pointer.block_pointer == last_block.block_pointer) { + if (last_block.IsValid() && next_pointer.block_pointer == last_block.block_pointer) { break; } - result.push_back(next_block_pointer); + result.push_back(next_pointer); ReadNextBlock(); } return result; @@ -77,18 +73,18 @@ void MetadataReader::ReadNextBlock(QueryContext context) { if (!has_next_block) { throw IOException("No more data remaining in MetadataReader"); } - block = manager.Pin(context, next_pointer); - index = next_pointer.index; + if (read_pointers) { + read_pointers->push_back(next_pointer); + } + auto next_disk_pointer = FromDiskPointer(next_pointer); + block = manager.Pin(context, next_disk_pointer); + index = next_disk_pointer.index; idx_t next_block = Load(BasePtr()); if (next_block == idx_t(-1)) { has_next_block = false; } else { - next_pointer = FromDiskPointer(MetaBlockPointer(next_block, 0)); - MetaBlockPointer next_block_pointer(next_block, 0); - if (read_pointers) { - read_pointers->push_back(next_block_pointer); - } + next_pointer = MetaBlockPointer(next_block, 0); } if (next_offset < sizeof(block_id_t)) { next_offset = sizeof(block_id_t); diff --git a/src/duckdb/src/storage/metadata/metadata_writer.cpp b/src/duckdb/src/storage/metadata/metadata_writer.cpp index 69d8ea87e..8e7138b7d 100644 --- a/src/duckdb/src/storage/metadata/metadata_writer.cpp +++ b/src/duckdb/src/storage/metadata/metadata_writer.cpp @@ -32,7 +32,7 @@ MetaBlockPointer MetadataWriter::GetMetaBlockPointer() { void MetadataWriter::SetWrittenPointers(optional_ptr> written_pointers_p) { written_pointers = written_pointers_p; - if (written_pointers && capacity > 0) { + if (written_pointers && capacity > 0 && offset < capacity) { written_pointers->push_back(manager.GetDiskPointer(current_pointer)); } } diff --git a/src/duckdb/src/storage/optimistic_data_writer.cpp b/src/duckdb/src/storage/optimistic_data_writer.cpp index 4f595223f..e3fe08fe1 100644 --- a/src/duckdb/src/storage/optimistic_data_writer.cpp +++ b/src/duckdb/src/storage/optimistic_data_writer.cpp @@ -6,6 +6,9 @@ namespace duckdb { +OptimisticWriteCollection::~OptimisticWriteCollection() { +} + OptimisticDataWriter::OptimisticDataWriter(ClientContext &context, DataTable &table) : context(context), table(table) { } @@ -28,14 +31,14 @@ bool OptimisticDataWriter::PrepareWrite() { // allocate the partial block-manager if none is allocated yet if (!partial_manager) { auto 
&block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); - partial_manager = - make_uniq(QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE); + partial_manager = make_uniq(context, block_manager, PartialBlockType::APPEND_TO_TABLE); } return true; } unique_ptr OptimisticDataWriter::CreateCollection(DataTable &storage, - const vector &insert_types) { + const vector &insert_types, + OptimisticWritePartialManagers type) { auto table_info = storage.GetDataTableInfo(); auto &io_manager = TableIOManager::Get(storage); @@ -45,6 +48,13 @@ unique_ptr OptimisticDataWriter::CreateCollection(Dat auto result = make_uniq(); result->collection = std::move(row_groups); + if (type == OptimisticWritePartialManagers::PER_COLUMN) { + for (idx_t i = 0; i < insert_types.size(); i++) { + auto &block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); + result->partial_block_managers.push_back(make_uniq( + QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE)); + } + } return result; } @@ -58,11 +68,14 @@ void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_group auto unflushed_row_groups = row_groups.complete_row_groups - row_groups.last_flushed; if (unflushed_row_groups >= DBConfig::GetSetting(context)) { // we have crossed our flush threshold - flush any unwritten row groups to disk - vector> to_flush; + vector> to_flush; + vector segment_indexes; for (idx_t i = row_groups.last_flushed; i < row_groups.complete_row_groups; i++) { - to_flush.push_back(*row_groups.collection->GetRowGroup(NumericCast(i))); + auto segment_index = NumericCast(i); + to_flush.push_back(*row_groups.collection->GetRowGroup(segment_index)); + segment_indexes.push_back(segment_index); } - FlushToDisk(to_flush); + FlushToDisk(row_groups, to_flush, segment_indexes); row_groups.last_flushed = row_groups.complete_row_groups; } } @@ -73,36 +86,56 @@ void OptimisticDataWriter::WriteLastRowGroup(OptimisticWriteCollection &row_grou return; } // flush the last batch of row groups - vector> to_flush; + vector> to_flush; + vector segment_indexes; for (idx_t i = row_groups.last_flushed; i < row_groups.complete_row_groups; i++) { - to_flush.push_back(*row_groups.collection->GetRowGroup(NumericCast(i))); + auto segment_index = NumericCast(i); + to_flush.push_back(*row_groups.collection->GetRowGroup(segment_index)); + segment_indexes.push_back(segment_index); } // add the last (incomplete) row group to_flush.push_back(*row_groups.collection->GetRowGroup(-1)); - FlushToDisk(to_flush); + segment_indexes.push_back(-1); + + FlushToDisk(row_groups, to_flush, segment_indexes); + + for (auto &partial_manager : row_groups.partial_block_managers) { + Merge(partial_manager); + } + row_groups.partial_block_managers.clear(); } -void OptimisticDataWriter::FlushToDisk(const vector> &row_groups) { +void OptimisticDataWriter::FlushToDisk(OptimisticWriteCollection &collection, + const vector> &row_groups, + const vector &segment_indexes) { //! 
The set of column compression types (if any) vector compression_types; D_ASSERT(compression_types.empty()); for (auto &column : table.Columns()) { compression_types.push_back(column.CompressionType()); } - RowGroupWriteInfo info(*partial_manager, compression_types); - RowGroup::WriteToDisk(info, row_groups); + RowGroupWriteInfo info(*partial_manager, compression_types, collection.partial_block_managers); + auto result = RowGroup::WriteToDisk(info, row_groups); + // move new (checkpointed) row groups to the row group collection + for (idx_t i = 0; i < row_groups.size(); i++) { + collection.collection->SetRowGroup(segment_indexes[i], std::move(result[i].result_row_group)); + } } -void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { - if (!other.partial_manager) { +void OptimisticDataWriter::Merge(unique_ptr &other_manager) { + if (!other_manager) { return; } if (!partial_manager) { - partial_manager = std::move(other.partial_manager); + partial_manager = std::move(other_manager); return; } - partial_manager->Merge(*other.partial_manager); - other.partial_manager.reset(); + partial_manager->Merge(*other_manager); + other_manager.reset(); +} + +void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { + Merge(other.partial_manager); } void OptimisticDataWriter::FinalFlush() { diff --git a/src/duckdb/src/storage/partial_block_manager.cpp b/src/duckdb/src/storage/partial_block_manager.cpp index 27fe86cd3..26c3e6212 100644 --- a/src/duckdb/src/storage/partial_block_manager.cpp +++ b/src/duckdb/src/storage/partial_block_manager.cpp @@ -22,7 +22,6 @@ void PartialBlock::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, ui } void PartialBlock::FlushInternal(const idx_t free_space_left) { - // ensure that we do not leak any data if (free_space_left > 0 || !uninitialized_regions.empty()) { auto buffer_handle = block_manager.buffer_manager.Pin(block_handle); @@ -45,7 +44,6 @@ PartialBlockManager::PartialBlockManager(QueryContext context, BlockManager &blo uint32_t max_use_count) : context(context.GetClientContext()), block_manager(block_manager), partial_block_type(partial_block_type), max_use_count(max_use_count) { - if (max_partial_block_size_p.IsValid()) { max_partial_block_size = NumericCast(max_partial_block_size_p.GetIndex()); return; @@ -88,7 +86,7 @@ bool PartialBlockManager::HasBlockAllocation(uint32_t segment_size) { void PartialBlockManager::AllocateBlock(PartialBlockState &state, uint32_t segment_size) { D_ASSERT(segment_size <= block_manager.GetBlockSize()); if (partial_block_type == PartialBlockType::FULL_CHECKPOINT) { - state.block_id = block_manager.GetFreeBlockId(); + state.block_id = GetFreeBlockId(); } else { state.block_id = INVALID_BLOCK; } @@ -97,6 +95,14 @@ void PartialBlockManager::AllocateBlock(PartialBlockState &state, uint32_t segme state.block_use_count = 1; } +block_id_t PartialBlockManager::GetFreeBlockId() { + if (partial_block_type == PartialBlockType::FULL_CHECKPOINT) { + return block_manager.GetFreeBlockIdForCheckpoint(); + } else { + return block_manager.GetFreeBlockId(); + } +} + bool PartialBlockManager::GetPartialBlock(idx_t segment_size, unique_ptr &partial_block) { auto entry = partially_filled_blocks.lower_bound(segment_size); if (entry == partially_filled_blocks.end()) { diff --git a/src/duckdb/src/storage/serialization/serialize_nodes.cpp b/src/duckdb/src/storage/serialization/serialize_nodes.cpp index b87ba38ec..ac3959177 100644 --- a/src/duckdb/src/storage/serialization/serialize_nodes.cpp +++ 
b/src/duckdb/src/storage/serialization/serialize_nodes.cpp @@ -252,7 +252,7 @@ ColumnList ColumnList::Deserialize(Deserializer &deserializer) { void CommonTableExpressionInfo::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(100, "aliases", aliases); serializer.WritePropertyWithDefault>(101, "query", query); - serializer.WriteProperty(102, "materialized", materialized); + serializer.WriteProperty(102, "materialized", GetMaterializedForSerialization(serializer)); serializer.WritePropertyWithDefault>>(103, "key_targets", key_targets); } diff --git a/src/duckdb/src/storage/serialization/serialize_query_node.cpp b/src/duckdb/src/storage/serialization/serialize_query_node.cpp index 50ab535d2..25b167558 100644 --- a/src/duckdb/src/storage/serialization/serialize_query_node.cpp +++ b/src/duckdb/src/storage/serialization/serialize_query_node.cpp @@ -38,6 +38,9 @@ unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { } result->modifiers = std::move(modifiers); result->cte_map = std::move(cte_map); + if (type == QueryNodeType::CTE_NODE) { + result = std::move(result->Cast().child); + } return result; } diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp index 453961009..963d5646e 100644 --- a/src/duckdb/src/storage/serialization/serialize_types.cpp +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -42,6 +42,9 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) case ExtraTypeInfoType::GENERIC_TYPE_INFO: result = make_shared_ptr(type); break; + case ExtraTypeInfoType::GEO_TYPE_INFO: + result = GeoTypeInfo::Deserialize(deserializer); + break; case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO: result = IntegerLiteralTypeInfo::Deserialize(deserializer); break; @@ -136,6 +139,15 @@ unique_ptr ExtensionTypeInfo::Deserialize(Deserializer &deser return result; } +void GeoTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); +} + +shared_ptr GeoTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new GeoTypeInfo()); + return std::move(result); +} + void IntegerLiteralTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); serializer.WriteProperty(200, "constant_value", constant_value); diff --git a/src/duckdb/src/storage/single_file_block_manager.cpp b/src/duckdb/src/storage/single_file_block_manager.cpp index 6d22ff423..02c7c93aa 100644 --- a/src/duckdb/src/storage/single_file_block_manager.cpp +++ b/src/duckdb/src/storage/single_file_block_manager.cpp @@ -14,6 +14,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/block_allocator.hpp" #include "duckdb/storage/metadata/metadata_reader.hpp" #include "duckdb/storage/metadata/metadata_writer.hpp" #include "duckdb/storage/storage_info.hpp" @@ -66,13 +67,12 @@ void DeserializeEncryptionData(ReadStream &stream, data_t *dest, idx_t size) { void GenerateDBIdentifier(uint8_t *db_identifier) { memset(db_identifier, 0, MainHeader::DB_IDENTIFIER_LEN); - duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(db_identifier, - MainHeader::DB_IDENTIFIER_LEN); + RandomEngine engine; + engine.RandomData(db_identifier, MainHeader::DB_IDENTIFIER_LEN); } void EncryptCanary(MainHeader &main_header, const shared_ptr &encryption_state, const_data_ptr_t derived_key) { - uint8_t 
canary_buffer[MainHeader::CANARY_BYTE_SIZE]; // we zero-out the iv and the (not yet) encrypted canary @@ -248,7 +248,7 @@ DatabaseHeader DeserializeDatabaseHeader(const MainHeader &main_header, data_ptr SingleFileBlockManager::SingleFileBlockManager(AttachedDatabase &db_p, const string &path_p, const StorageManagerOptions &options) : BlockManager(BufferManager::GetBufferManager(db_p), options.block_alloc_size, options.block_header_size), - db(db_p), path(path_p), header_buffer(Allocator::Get(db_p), FileBufferType::MANAGED_BUFFER, + db(db_p), path(path_p), header_buffer(BlockAllocator::Get(db_p), FileBufferType::MANAGED_BUFFER, Storage::FILE_HEADER_SIZE - options.block_header_size.GetIndex(), options.block_header_size.GetIndex()), iteration_count(0), options(options) { @@ -362,6 +362,15 @@ void SingleFileBlockManager::CheckAndAddEncryptionKey(MainHeader &main_header) { void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { auto flags = GetFileFlags(true); + auto encryption_enabled = options.encryption_options.encryption_enabled; + if (encryption_enabled) { + if (!db.GetDatabase().GetEncryptionUtil()->SupportsEncryption() && !options.read_only) { + throw InvalidConfigurationException( + "The database was opened with encryption enabled, but DuckDB currently has a read-only crypto module " + "loaded. Please re-open using READONLY, or ensure httpfs is loaded using `LOAD httpfs`."); + } + } + // open the RDBMS handle auto &fs = FileSystem::Get(db); handle = fs.OpenFile(path, flags); @@ -376,7 +385,6 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { // Derive the encryption key and add it to the cache. // Not used for plain databases. data_t derived_key[MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH]; - auto encryption_enabled = options.encryption_options.encryption_enabled; // We need the unique database identifier, if the storage version is new enough. // If encryption is enabled, we also use it as the salt. @@ -487,6 +495,15 @@ void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) { if (main_header.IsEncrypted()) { if (options.encryption_options.encryption_enabled) { //! Encryption is set + + //! Check if our encryption module can write, if not, we should throw here + if (!db.GetDatabase().GetEncryptionUtil()->SupportsEncryption() && !options.read_only) { + throw InvalidConfigurationException( + "The database is encrypted, but DuckDB currently has a read-only crypto module loaded. Either " + "re-open the database using `ATTACH '..' (READONLY)`, or ensure httpfs is loaded using `LOAD " + "httpfs`."); + } + //! Check if the given key upon attach is correct // Derive the encryption key and add it to cache CheckAndAddEncryptionKey(main_header); @@ -506,6 +523,19 @@ void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) { path, EncryptionTypes::CipherToString(config_cipher), EncryptionTypes::CipherToString(stored_cipher)); } + + // This prevents the cipher from being downgraded by an attacker. FIXME: we likely want to have a proper validation + // of the cipher used instead of this trick to avoid downgrades + if (stored_cipher != EncryptionTypes::GCM) { + if (config_cipher == EncryptionTypes::INVALID) { + throw CatalogException( + "Cannot open encrypted database \"%s\" without explicitly specifying the " + "encryption cipher for security reasons. 
Please make sure you understand the security implications " + "and re-attach the database specifying the desired cipher.", + path); + } + } + // this is ugly, but the storage manager does not know the cipher type before db.GetStorageManager().SetCipher(stored_cipher); } @@ -620,7 +650,7 @@ void SingleFileBlockManager::ChecksumAndWrite(QueryContext context, FileBuffer & if (options.encryption_options.encryption_enabled && !skip_block_header) { auto key_id = options.encryption_options.derived_key_id; temp_buffer_manager = - make_uniq(Allocator::Get(db), block.GetBufferType(), block.Size(), GetBlockHeaderSize()); + make_uniq(BlockAllocator::Get(db), block.GetBufferType(), block.Size(), GetBlockHeaderSize()); EncryptionEngine::EncryptBlock(db, key_id, block, *temp_buffer_manager, delta); temp_buffer_manager->Write(context, *handle, location); } else { @@ -677,7 +707,6 @@ void SingleFileBlockManager::LoadFreeList(QueryContext context) { for (idx_t i = 0; i < free_list_count; i++) { auto block = reader.Read(context); free_list.insert(block); - newly_freed_list.insert(block); } auto multi_use_blocks_count = reader.Read(context); multi_use_blocks.clear(); @@ -694,21 +723,37 @@ bool SingleFileBlockManager::IsRootBlock(MetaBlockPointer root) { return root.block_pointer == meta_block; } -block_id_t SingleFileBlockManager::GetFreeBlockId() { - lock_guard lock(block_lock); +block_id_t SingleFileBlockManager::GetFreeBlockIdInternal(FreeBlockType type) { block_id_t block; - if (!free_list.empty()) { - // The free list is not empty, so we take its first element. - block = *free_list.begin(); - // erase the entry from the free list again - free_list.erase(free_list.begin()); - newly_freed_list.erase(block); - } else { - block = max_block++; + { + lock_guard lock(block_lock); + if (!free_list.empty()) { + // The free list is not empty, so we take its first element. 
+ block = *free_list.begin(); + // erase the entry from the free list again + free_list.erase(free_list.begin()); + } else { + block = max_block++; + } + // add the entry to the list of newly used blocks + if (type == FreeBlockType::NEWLY_USED_BLOCK) { + newly_used_blocks.insert(block); + } + } + if (BlockIsRegistered(block)) { + throw InternalException("Free block %d is already registered", block); } return block; } +block_id_t SingleFileBlockManager::GetFreeBlockId() { + return GetFreeBlockIdInternal(FreeBlockType::NEWLY_USED_BLOCK); +} + +block_id_t SingleFileBlockManager::GetFreeBlockIdForCheckpoint() { + return GetFreeBlockIdInternal(FreeBlockType::CHECKPOINTED_BLOCK); +} + block_id_t SingleFileBlockManager::PeekFreeBlockId() { lock_guard lock(block_lock); if (!free_list.empty()) { @@ -718,16 +763,10 @@ block_id_t SingleFileBlockManager::PeekFreeBlockId() { } } -void SingleFileBlockManager::MarkBlockAsFree(block_id_t block_id) { +void SingleFileBlockManager::MarkBlockACheckpointed(block_id_t block_id) { lock_guard lock(block_lock); D_ASSERT(block_id >= 0); - D_ASSERT(block_id < max_block); - if (free_list.find(block_id) != free_list.end()) { - throw InternalException("MarkBlockAsFree called but block %llu was already freed!", block_id); - } - multi_use_blocks.erase(block_id); - free_list.insert(block_id); - newly_freed_list.insert(block_id); + newly_used_blocks.erase(block_id); } void SingleFileBlockManager::MarkBlockAsUsed(block_id_t block_id) { @@ -746,7 +785,6 @@ void SingleFileBlockManager::MarkBlockAsUsed(block_id_t block_id) { } else if (free_list.find(block_id) != free_list.end()) { // block is currently in the free list - erase free_list.erase(block_id); - newly_freed_list.erase(block_id); } else { // block is already in use - increase reference count IncreaseBlockReferenceCountInternal(block_id); @@ -771,10 +809,27 @@ void SingleFileBlockManager::MarkBlockAsModified(block_id_t block_id) { return; } // Check for multi-free - // TODO: Fix the bug that causes this assert to fire, then uncomment it. 
- // D_ASSERT(modified_blocks.find(block_id) == modified_blocks.end()); - D_ASSERT(free_list.find(block_id) == free_list.end()); - modified_blocks.insert(block_id); + if (modified_blocks.find(block_id) != modified_blocks.end()) { + throw InternalException("MarkBlockAsModified called with already modified block id %d", block_id); + } + if (free_list.find(block_id) != free_list.end()) { + throw InternalException("MarkBlockAsModified called with already freed block id %d", block_id); + } + auto newly_used_entry = newly_used_blocks.find(block_id); + if (newly_used_entry != newly_used_blocks.end()) { + // this block was newly used - and now we are labeling it as no longer being required + // we can directly add it back to the free list + newly_used_blocks.erase(block_id); + if (BlockIsRegistered(block_id)) { + free_blocks_in_use.insert(block_id); + } else { + free_list.insert(block_id); + } + } else { + // this block was used in storage, we cannot directly re-use it + // add it to the modified blocks indicating it will be re-usable after the next checkpoint + modified_blocks.insert(block_id); + } } void SingleFileBlockManager::IncreaseBlockReferenceCountInternal(block_id_t block_id) { @@ -823,9 +878,15 @@ void SingleFileBlockManager::VerifyBlocks(const unordered_map } } } + for (auto &newly_used_block : newly_used_blocks) { + referenced_blocks.insert(newly_used_block); + } for (auto &free_block : free_list) { referenced_blocks.insert(free_block); } + for (auto &free_block : free_blocks_in_use) { + referenced_blocks.insert(free_block); + } if (referenced_blocks.size() != NumericCast(max_block)) { // not all blocks are accounted for string missing_blocks; @@ -837,9 +898,39 @@ void SingleFileBlockManager::VerifyBlocks(const unordered_map missing_blocks += to_string(i); } } + string free_list_str; + for (auto &block : free_list) { + if (!free_list_str.empty()) { + free_list_str += ", "; + } + free_list_str += to_string(block); + } + string block_usage_str; + for (auto &entry : block_usage_count) { + if (!block_usage_str.empty()) { + block_usage_str += ", "; + } + block_usage_str += to_string(entry.first); + } + string multi_use_blocks_str; + for (auto &entry : multi_use_blocks) { + if (!multi_use_blocks_str.empty()) { + multi_use_blocks_str += ", "; + } + multi_use_blocks_str += to_string(entry.first); + } + string newly_used_blocks_str; + for (auto &block : newly_used_blocks) { + if (!newly_used_blocks_str.empty()) { + newly_used_blocks_str += ", "; + } + newly_used_blocks_str += to_string(block); + } + throw InternalException( - "Blocks %s were neither present in the free list or in the block_usage_count (max block %lld)", - missing_blocks, max_block); + "Block verification failed - blocks \"%s\" were not found as being used OR marked as free\nMax block: " + "%d\nBlock usage: %s\nFree list: %s\nMulti-use blocks: %s\nNewly used blocks: %s", + missing_blocks, max_block, block_usage_str, free_list_str, multi_use_blocks_str, newly_used_blocks_str); } } @@ -892,7 +983,7 @@ unique_ptr SingleFileBlockManager::CreateBlock(block_id_t block_id, FileB if (source_buffer) { result = ConvertBlock(block_id, *source_buffer); } else { - result = make_uniq(Allocator::Get(db), block_id, *this); + result = make_uniq(BlockAllocator::Get(db), block_id, *this); } result->Initialize(options.debug_initialize); return result; @@ -963,6 +1054,8 @@ void SingleFileBlockManager::Write(QueryContext context, FileBuffer &buffer, blo void SingleFileBlockManager::Truncate() { BlockManager::Truncate(); + + lock_guard 
guard(block_lock); idx_t blocks_to_truncate = 0; // reverse iterate over the free-list for (auto entry = free_list.rbegin(); entry != free_list.rend(); entry++) { @@ -979,7 +1072,6 @@ } // truncate the file free_list.erase(free_list.lower_bound(max_block), free_list.end()); - newly_freed_list.erase(newly_freed_list.lower_bound(max_block), newly_freed_list.end()); handle->Truncate(NumericCast(BLOCK_START + NumericCast(max_block) * GetBlockAllocSize())); } @@ -1036,13 +1128,28 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he // add all modified blocks to the free list: they can now be written to again metadata_manager.MarkBlocksAsModified(); - lock_guard lock(block_lock); + unique_lock lock(block_lock); // set the iteration count header.iteration = ++iteration_count; + set all_free_blocks = free_list; + set fully_freed_blocks; for (auto &block : modified_blocks) { - free_list.insert(block); - newly_freed_list.insert(block); + all_free_blocks.insert(block); + if (!BlockIsRegistered(block)) { + // if the block is no longer registered it is not in use - so it can be re-used after this point + free_list.insert(block); + fully_freed_blocks.insert(block); + } else { + // if the block is still registered it is still in use - keep it in the free_blocks_in_use list + free_blocks_in_use.insert(block); + } + } + auto written_multi_use_blocks = multi_use_blocks; + // newly used blocks are still free blocks for this checkpoint - so add them to the free list that we write + for (auto &newly_used_block : newly_used_blocks) { + all_free_blocks.insert(newly_used_block); + written_multi_use_blocks.erase(newly_used_block); } modified_blocks.clear(); @@ -1056,12 +1163,12 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he auto ptr = writer.GetMetaBlockPointer(); header.free_list = ptr.block_pointer; - writer.Write(free_list.size()); - for (auto &block_id : free_list) { + writer.Write(all_free_blocks.size()); + for (auto &block_id : all_free_blocks) { writer.Write(block_id); } - writer.Write(multi_use_blocks.size()); - for (auto &entry : multi_use_blocks) { + writer.Write(written_multi_use_blocks.size()); + for (auto &entry : written_multi_use_blocks) { writer.Write(entry.first); writer.Write(entry.second); } @@ -1071,8 +1178,13 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he // no blocks in the free list header.free_list = DConstants::INVALID_INDEX; } + lock.unlock(); metadata_manager.Flush(); + + lock.lock(); header.block_count = NumericCast(max_block); + lock.unlock(); + header.serialization_compatibility = options.storage_version.GetIndex(); auto debug_checkpoint_abort = DBConfig::GetSetting(db.GetDatabase()); @@ -1107,31 +1219,48 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he active_header = 1 - active_header; //! Ensure the header write ends up on disk handle->Sync(); - // Release the free blocks to the filesystem. - TrimFreeBlocks(); + // Release the fully freed blocks to the filesystem. + TrimFreeBlocks(fully_freed_blocks); } void SingleFileBlockManager::FileSync() { handle->Sync(); } -void SingleFileBlockManager::TrimFreeBlocks() { - if (DBConfig::Get(db).options.trim_free_blocks) { - for (auto itr = newly_freed_list.begin(); itr != newly_freed_list.end(); ++itr) { - block_id_t first = *itr; - block_id_t last = first; - // Find end of contiguous range. 
- for (++itr; itr != newly_freed_list.end() && (*itr == last + 1); ++itr) { - last = *itr; - } - // We are now one too far. - --itr; - // Trim the range. - handle->Trim(BLOCK_START + (NumericCast(first) * GetBlockAllocSize()), - NumericCast(last + 1 - first) * GetBlockAllocSize()); +void SingleFileBlockManager::UnregisterBlock(block_id_t id) { + // perform the actual unregistration + BlockManager::UnregisterBlock(id); + // check if it is part of the newly free list + lock_guard lock(block_lock); + auto entry = free_blocks_in_use.find(id); + if (entry != free_blocks_in_use.end()) { + // it is! move it to the regular free list so the block can be re-used + free_list.insert(id); + free_blocks_in_use.erase(entry); + } +} + +void SingleFileBlockManager::TrimFreeBlockRange(block_id_t start, block_id_t end) { + auto block_count = NumericCast(end + 1 - start); + handle->Trim(BLOCK_START + (NumericCast(start) * GetBlockAllocSize()), block_count * GetBlockAllocSize()); +} + +void SingleFileBlockManager::TrimFreeBlocks(const set &blocks) { + if (!DBConfig::Get(db).options.trim_free_blocks) { + return; + } + for (auto itr = blocks.begin(); itr != blocks.end(); ++itr) { + block_id_t first = *itr; + block_id_t last = first; + // Find end of contiguous range. + for (++itr; itr != blocks.end() && (*itr == last + 1); ++itr) { + last = *itr; } + // We are now one too far. + --itr; + // Trim the range. + TrimFreeBlockRange(first, last); } - newly_freed_list.clear(); } } // namespace duckdb diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp index e15986e1c..a8062b67a 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -11,6 +11,7 @@ #include "duckdb/storage/storage_manager.hpp" #include "duckdb/storage/temporary_file_manager.hpp" #include "duckdb/storage/temporary_memory_manager.hpp" +#include "duckdb/storage/block_allocator.hpp" #include "duckdb/common/encryption_functions.hpp" #include "duckdb/main/settings.hpp" @@ -48,7 +49,7 @@ unique_ptr StandardBufferManager::ConstructManagedBuffer(idx_t size, result = make_uniq(*tmp, type, block_header_size); } else { // non re-usable buffer: allocate a new buffer - result = make_uniq(Allocator::Get(db), type, size, block_header_size); + result = make_uniq(BlockAllocator::Get(db), type, size, block_header_size); } result->Initialize(DBConfig::GetConfig(db).options.debug_initialize); return result; @@ -338,7 +339,7 @@ BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { return Pin(QueryContext(), handle); } -BufferHandle StandardBufferManager::Pin(QueryContext context, shared_ptr &handle) { +BufferHandle StandardBufferManager::Pin(const QueryContext &context, shared_ptr &handle) { // we need to be careful not to return the BufferHandle to this block while holding the BlockHandle's lock // as exiting this function's scope may cause the destructor of the BufferHandle to be called while holding the lock // the destructor calls Unpin, which grabs the BlockHandle's lock again, causing a deadlock @@ -409,15 +410,16 @@ void StandardBufferManager::AddToEvictionQueue(shared_ptr &handle) void StandardBufferManager::VerifyZeroReaders(BlockLock &lock, shared_ptr &handle) { #ifdef DUCKDB_DEBUG_DESTROY_BLOCKS unique_ptr replacement_buffer; - auto &allocator = Allocator::Get(db); + auto &block_allocator = BlockAllocator::Get(db); auto &buffer = handle->GetBuffer(lock); auto block_header_size = buffer->GetHeaderSize(); auto alloc_size = 
buffer->AllocSize() - block_header_size; if (handle->GetBufferType() == FileBufferType::BLOCK) { auto block = reinterpret_cast(buffer.get()); - replacement_buffer = make_uniq(allocator, block->id, alloc_size, block_header_size); + replacement_buffer = make_uniq(block_allocator, block->id, alloc_size, block_header_size); } else { - replacement_buffer = make_uniq(allocator, buffer->GetBufferType(), alloc_size, block_header_size); + replacement_buffer = + make_uniq(block_allocator, buffer->GetBufferType(), alloc_size, block_header_size); } memcpy(replacement_buffer->buffer, buffer->buffer, buffer->size); WriteGarbageIntoBuffer(lock, *handle); @@ -495,7 +497,6 @@ void StandardBufferManager::RequireTemporaryDirectory() { } void StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, FileBuffer &buffer) { - // WriteTemporaryBuffer assumes that we never write a buffer below DEFAULT_BLOCK_ALLOC_SIZE. RequireTemporaryDirectory(); @@ -543,8 +544,10 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(QueryContext c BlockHandle &block, unique_ptr reusable_buffer) { D_ASSERT(!temporary_directory.path.empty()); - D_ASSERT(temporary_directory.handle.get()); auto id = block.BlockId(); + if (!temporary_directory.handle) { + throw InternalException("ReadTemporaryBuffer called but temporary directory has not been instantiated yet"); + } if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { // This is a block that was offloaded to a regular .tmp file, the file contains blocks of a fixed size return temporary_directory.handle->GetTempFile().ReadTemporaryBuffer(context, id, std::move(reusable_buffer)); @@ -642,6 +645,10 @@ bool StandardBufferManager::HasFilesInTemporaryDirectory() const { return found; } +BlockManager &StandardBufferManager::GetTemporaryBlockManager() { + return *temp_block_manager; +} + vector StandardBufferManager::GetTemporaryFiles() { vector result; if (temporary_directory.path.empty()) { diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp index 89ae9cb61..3ff390b17 100644 --- a/src/duckdb/src/storage/statistics/base_statistics.cpp +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/statistics/list_stats.hpp" #include "duckdb/storage/statistics/struct_stats.hpp" @@ -31,6 +32,9 @@ void BaseStatistics::Construct(BaseStatistics &stats, LogicalType type) { case StatisticsType::ARRAY_STATS: ArrayStats::Construct(stats); break; + case StatisticsType::VARIANT_STATS: + VariantStats::Construct(stats); + break; default: break; } @@ -62,6 +66,12 @@ StatisticsType BaseStatistics::GetStatsType(const LogicalType &type) { if (type.id() == LogicalTypeId::SQLNULL) { return StatisticsType::BASE_STATS; } + if (type.id() == LogicalTypeId::GEOMETRY) { + return StatisticsType::GEOMETRY_STATS; + } + if (type.id() == LogicalTypeId::VARIANT) { + return StatisticsType::VARIANT_STATS; + } switch (type.InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: @@ -103,7 +113,7 @@ void BaseStatistics::InitializeUnknown() { void BaseStatistics::InitializeEmpty() { has_null = false; - has_no_null = true; + has_no_null = false; } bool BaseStatistics::CanHaveNull() const { @@ -153,6 +163,12 @@ void 
BaseStatistics::Merge(const BaseStatistics &other) { case StatisticsType::ARRAY_STATS: ArrayStats::Merge(*this, other); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Merge(*this, other); + break; + case StatisticsType::VARIANT_STATS: + VariantStats::Merge(*this, other); + break; default: break; } @@ -174,6 +190,10 @@ BaseStatistics BaseStatistics::CreateUnknownType(LogicalType type) { return StructStats::CreateUnknown(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateUnknown(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateUnknown(std::move(type)); + case StatisticsType::VARIANT_STATS: + return VariantStats::CreateUnknown(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -191,6 +211,10 @@ BaseStatistics BaseStatistics::CreateEmptyType(LogicalType type) { return StructStats::CreateEmpty(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateEmpty(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateEmpty(std::move(type)); + case StatisticsType::VARIANT_STATS: + return VariantStats::CreateEmpty(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -219,8 +243,10 @@ BaseStatistics BaseStatistics::CreateEmpty(LogicalType type) { void BaseStatistics::Copy(const BaseStatistics &other) { D_ASSERT(GetType() == other.GetType()); CopyBase(other); + auto stats_type = GetStatsType(); + stats_union = other.stats_union; - switch (GetStatsType()) { + switch (stats_type) { case StatisticsType::LIST_STATS: ListStats::Copy(*this, other); break; @@ -230,6 +256,9 @@ void BaseStatistics::Copy(const BaseStatistics &other) { case StatisticsType::ARRAY_STATS: ArrayStats::Copy(*this, other); break; + case StatisticsType::VARIANT_STATS: + VariantStats::Copy(*this, other); + break; default: break; } @@ -278,6 +307,10 @@ void BaseStatistics::Set(StatsInfo info) { void BaseStatistics::SetHasNull() { has_null = true; + if (type.id() == LogicalTypeId::VARIANT) { + VariantStats::GetUnshreddedStats(*this).SetHasNull(); + return; + } if (type.InternalType() == PhysicalType::STRUCT) { for (idx_t c = 0; c < StructType::GetChildCount(type); c++) { StructStats::GetChildStats(*this, c).SetHasNull(); @@ -287,6 +320,10 @@ void BaseStatistics::SetHasNull() { void BaseStatistics::SetHasNoNull() { has_no_null = true; + if (type.id() == LogicalTypeId::VARIANT) { + VariantStats::GetUnshreddedStats(*this).SetHasNoNull(); + return; + } if (type.InternalType() == PhysicalType::STRUCT) { for (idx_t c = 0; c < StructType::GetChildCount(type); c++) { StructStats::GetChildStats(*this, c).SetHasNoNull(); @@ -294,7 +331,7 @@ void BaseStatistics::SetHasNoNull() { } } -void BaseStatistics::CombineValidity(BaseStatistics &left, BaseStatistics &right) { +void BaseStatistics::CombineValidity(const BaseStatistics &left, const BaseStatistics &right) { has_null = left.has_null || right.has_null; has_no_null = left.has_no_null || right.has_no_null; } @@ -329,6 +366,12 @@ void BaseStatistics::Serialize(Serializer &serializer) const { case StatisticsType::ARRAY_STATS: ArrayStats::Serialize(*this, serializer); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Serialize(*this, serializer); + break; + case StatisticsType::VARIANT_STATS: + VariantStats::Serialize(*this, serializer); + break; default: break; } @@ -367,6 +410,12 @@ BaseStatistics BaseStatistics::Deserialize(Deserializer &deserializer) { case StatisticsType::ARRAY_STATS: ArrayStats::Deserialize(obj, stats); 
break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Deserialize(obj, stats); + break; + case StatisticsType::VARIANT_STATS: + VariantStats::Deserialize(obj, stats); + break; default: break; } @@ -397,6 +446,12 @@ string BaseStatistics::ToString() const { case StatisticsType::ARRAY_STATS: result = ArrayStats::ToString(*this) + result; break; + case StatisticsType::GEOMETRY_STATS: + result = GeometryStats::ToString(*this) + result; + break; + case StatisticsType::VARIANT_STATS: + result = VariantStats::ToString(*this) + result; + break; default: break; } @@ -421,6 +476,12 @@ void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t co case StatisticsType::ARRAY_STATS: ArrayStats::Verify(*this, vector, sel, count); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Verify(*this, vector, sel, count); + break; + case StatisticsType::VARIANT_STATS: + VariantStats::Verify(*this, vector, sel, count); + break; default: break; } @@ -505,6 +566,25 @@ BaseStatistics BaseStatistics::FromConstantType(const Value &input) { } return result; } + case StatisticsType::GEOMETRY_STATS: { + auto result = GeometryStats::CreateEmpty(input.type()); + if (!input.IsNull()) { + auto &string_value = StringValue::Get(input); + GeometryStats::Update(result, string_t(string_value)); + } + return result; + } + case StatisticsType::VARIANT_STATS: { + auto result = VariantStats::CreateEmpty(input.type()); + auto unshredded_type = VariantShredding::GetUnshreddedType(); + if (input.IsNull()) { + VariantStats::SetUnshreddedStats(result, FromConstant(Value(unshredded_type))); + } else { + VariantStats::SetUnshreddedStats( + result, FromConstant(Value::STRUCT(unshredded_type, StructValue::GetChildren(input)))); + } + return result; + } default: return BaseStatistics(input.type()); } diff --git a/src/duckdb/src/storage/statistics/geometry_stats.cpp b/src/duckdb/src/storage/statistics/geometry_stats.cpp new file mode 100644 index 000000000..91ebeaa5f --- /dev/null +++ b/src/duckdb/src/storage/statistics/geometry_stats.cpp @@ -0,0 +1,280 @@ +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +vector GeometryTypeSet::ToString(bool snake_case) const { + vector result; + for (auto d = 0; d < VERT_TYPES; d++) { + for (auto i = 0; i < PART_TYPES; i++) { + if (sets[d] & (1 << i)) { + string str; + switch (i) { + case 1: + str = snake_case ? "point" : "Point"; + break; + case 2: + str = snake_case ? "linestring" : "LineString"; + break; + case 3: + str = snake_case ? "polygon" : "Polygon"; + break; + case 4: + str = snake_case ? "multipoint" : "MultiPoint"; + break; + case 5: + str = snake_case ? "multilinestring" : "MultiLineString"; + break; + case 6: + str = snake_case ? "multipolygon" : "MultiPolygon"; + break; + case 7: + str = snake_case ? "geometrycollection" : "GeometryCollection"; + break; + default: + str = snake_case ? "unknown" : "Unknown"; + break; + } + switch (d) { + case 1: + str += snake_case ? "_z" : " Z"; + break; + case 2: + str += snake_case ? "_m" : " M"; + break; + case 3: + str += snake_case ? 
"_zm" : " ZM"; + break; + default: + break; + } + + result.push_back(str); + } + } + } + return result; +} + +BaseStatistics GeometryStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + GetDataUnsafe(result).SetUnknown(); + return result; +} + +BaseStatistics GeometryStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + GetDataUnsafe(result).SetEmpty(); + return result; +} + +void GeometryStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + const auto &data = GetDataUnsafe(stats); + + // Write extent + serializer.WriteObject(200, "extent", [&](Serializer &extent) { + extent.WriteProperty(101, "x_min", data.extent.x_min); + extent.WriteProperty(102, "x_max", data.extent.x_max); + extent.WriteProperty(103, "y_min", data.extent.y_min); + extent.WriteProperty(104, "y_max", data.extent.y_max); + extent.WriteProperty(105, "z_min", data.extent.z_min); + extent.WriteProperty(106, "z_max", data.extent.z_max); + extent.WriteProperty(107, "m_min", data.extent.m_min); + extent.WriteProperty(108, "m_max", data.extent.m_max); + }); + + // Write types + serializer.WriteObject(201, "types", [&](Serializer &types) { + types.WriteProperty(101, "types_xy", data.types.sets[0]); + types.WriteProperty(102, "types_xyz", data.types.sets[1]); + types.WriteProperty(103, "types_xym", data.types.sets[2]); + types.WriteProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +void GeometryStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &data = GetDataUnsafe(base); + + // Read extent + deserializer.ReadObject(200, "extent", [&](Deserializer &extent) { + extent.ReadProperty(101, "x_min", data.extent.x_min); + extent.ReadProperty(102, "x_max", data.extent.x_max); + extent.ReadProperty(103, "y_min", data.extent.y_min); + extent.ReadProperty(104, "y_max", data.extent.y_max); + extent.ReadProperty(105, "z_min", data.extent.z_min); + extent.ReadProperty(106, "z_max", data.extent.z_max); + extent.ReadProperty(107, "m_min", data.extent.m_min); + extent.ReadProperty(108, "m_max", data.extent.m_max); + }); + + // Read types + deserializer.ReadObject(201, "types", [&](Deserializer &types) { + types.ReadProperty(101, "types_xy", data.types.sets[0]); + types.ReadProperty(102, "types_xyz", data.types.sets[1]); + types.ReadProperty(103, "types_xym", data.types.sets[2]); + types.ReadProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +string GeometryStats::ToString(const BaseStatistics &stats) { + const auto &data = GetDataUnsafe(stats); + string result; + + result += "["; + result += StringUtil::Format("Extent: [X: [%f, %f], Y: [%f, %f], Z: [%f, %f], M: [%f, %f]", data.extent.x_min, + data.extent.x_max, data.extent.y_min, data.extent.y_max, data.extent.z_min, + data.extent.z_max, data.extent.m_min, data.extent.m_max); + result += StringUtil::Format("], Types: [%s]", StringUtil::Join(data.types.ToString(true), ", ")); + + result += "]"; + return result; +} + +void GeometryStats::Update(BaseStatistics &stats, const string_t &value) { + auto &data = GetDataUnsafe(stats); + data.Update(value); +} + +void GeometryStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + if (other.GetType().id() == LogicalTypeId::SQLNULL) { + return; + } + + auto &target = GetDataUnsafe(stats); + auto &source = GetDataUnsafe(other); + target.Merge(source); +} + +void GeometryStats::Verify(const 
BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + // TODO: Verify stats +} + +const GeometryStatsData &GeometryStats::GetDataUnsafe(const BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +GeometryStatsData &GeometryStats::GetDataUnsafe(BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +GeometryExtent &GeometryStats::GetExtent(BaseStatistics &stats) { + return GetDataUnsafe(stats).extent; +} + +GeometryTypeSet &GeometryStats::GetTypes(BaseStatistics &stats) { + return GetDataUnsafe(stats).types; +} + +const GeometryExtent &GeometryStats::GetExtent(const BaseStatistics &stats) { + return GetDataUnsafe(stats).extent; +} + +const GeometryTypeSet &GeometryStats::GetTypes(const BaseStatistics &stats) { + return GetDataUnsafe(stats).types; +} + +// Expression comparison pruning +static FilterPropagateResult CheckIntersectionFilter(const GeometryStatsData &data, const Value &constant) { + if (constant.IsNull() || constant.type().id() != LogicalTypeId::GEOMETRY) { + // Cannot prune against NULL + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + // This has been checked before and needs to be true for the checks below to be valid + D_ASSERT(data.extent.HasXY()); + + const auto &geom = StringValue::Get(constant); + auto extent = GeometryExtent::Empty(); + if (Geometry::GetExtent(string_t(geom), extent) == 0) { + // If the geometry is empty, the predicate will never match + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + // Check if the bounding boxes intersect + // If the bounding boxes do not intersect, the predicate will never match + if (!extent.IntersectsXY(data.extent)) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + // If the column is completely inside the bounds, the predicate will always match + if (extent.ContainsXY(data.extent)) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + + // We cannot prune, as this column may contain geometries that intersect + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, const unique_ptr &expr) { + if (expr->GetExpressionType() != ExpressionType::BOUND_FUNCTION) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + if (expr->return_type != LogicalType::BOOLEAN) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + const auto &func = expr->Cast(); + if (func.children.size() != 2) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + if (func.children[0]->return_type.id() != LogicalTypeId::GEOMETRY || + func.children[1]->return_type.id() != LogicalTypeId::GEOMETRY) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + // The set of geometry predicates that can be optimized using the bounding box + static constexpr const char *geometry_predicates[2] = {"&&", "st_intersects_extent"}; + + auto found = false; + for (const auto &name : geometry_predicates) { + if (StringUtil::CIEquals(func.function.name.c_str(), name)) { + found = true; + break; + } + } + if (!found) { + // Not a geometry predicate we can optimize + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto lhs_kind = func.children[0]->GetExpressionType(); + const auto rhs_kind = func.children[1]->GetExpressionType(); + const auto lhs_is_const = lhs_kind == ExpressionType::VALUE_CONSTANT && rhs_kind == 
ExpressionType::BOUND_REF; + const auto rhs_is_const = rhs_kind == ExpressionType::VALUE_CONSTANT && lhs_kind == ExpressionType::BOUND_REF; + + if (!stats.CanHaveNoNull()) { + // no non-null values are possible: always false + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + auto &data = GetDataUnsafe(stats); + + if (!data.extent.HasXY()) { + // If the extent is empty or unknown, we cannot prune + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + if (lhs_is_const) { + return CheckIntersectionFilter(data, func.children[0]->Cast().value); + } + if (rhs_is_const) { + return CheckIntersectionFilter(data, func.children[1]->Cast().value); + } + // Else, no constant argument + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/numeric_stats.cpp b/src/duckdb/src/storage/statistics/numeric_stats.cpp index 803c21f12..225524406 100644 --- a/src/duckdb/src/storage/statistics/numeric_stats.cpp +++ b/src/duckdb/src/storage/statistics/numeric_stats.cpp @@ -237,6 +237,7 @@ FilterPropagateResult NumericStats::CheckZonemap(const BaseStatistics &stats, Ex if (!NumericStats::HasMinMax(stats)) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } + D_ASSERT(stats.CanHaveNoNull()); switch (stats.GetType().InternalType()) { case PhysicalType::INT8: return CheckZonemapTemplated(stats, comparison_type, constants); diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp index e7d232692..90107ae72 100644 --- a/src/duckdb/src/storage/statistics/string_stats.cpp +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -170,6 +170,14 @@ void StringStats::Update(BaseStatistics &stats, const string_t &value) { } } +void StringStats::SetMin(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).min); +} + +void StringStats::SetMax(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).max); +} + void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; @@ -193,6 +201,7 @@ void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { FilterPropagateResult StringStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, array_ptr constants) { auto &string_data = StringStats::GetDataUnsafe(stats); + D_ASSERT(stats.CanHaveNoNull()); for (auto &constant_value : constants) { D_ASSERT(constant_value.type() == stats.GetType()); D_ASSERT(!constant_value.IsNull()); diff --git a/src/duckdb/src/storage/statistics/variant_stats.cpp b/src/duckdb/src/storage/statistics/variant_stats.cpp new file mode 100644 index 000000000..0590bef5d --- /dev/null +++ b/src/duckdb/src/storage/statistics/variant_stats.cpp @@ -0,0 +1,527 @@ +#include "duckdb/storage/statistics/variant_stats.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/variant.hpp" + +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +#include "duckdb/common/types/variant_visitor.hpp" +#include 
"duckdb/function/variant/variant_shredding.hpp" + +namespace duckdb { + +static void AssertVariant(const BaseStatistics &stats) { + if (DUCKDB_UNLIKELY(stats.GetStatsType() != StatisticsType::VARIANT_STATS)) { + throw InternalException( + "Calling a VariantStats method on BaseStatistics that are not of type VARIANT, but of type %s", + EnumUtil::ToString(stats.GetStatsType())); + } +} + +void VariantStats::Construct(BaseStatistics &stats) { + stats.child_stats = unsafe_unique_array(new BaseStatistics[2]); + GetDataUnsafe(stats).shredding_state = VariantStatsShreddingState::UNINITIALIZED; + CreateUnshreddedStats(stats); +} + +BaseStatistics VariantStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + //! Unknown - we have no clue what's in this + GetDataUnsafe(result).shredding_state = VariantStatsShreddingState::INCONSISTENT; + result.child_stats[0].Copy(BaseStatistics::CreateUnknown(VariantShredding::GetUnshreddedType())); + return result; +} + +BaseStatistics VariantStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + GetDataUnsafe(result).shredding_state = VariantStatsShreddingState::UNINITIALIZED; + result.child_stats[0].Copy(BaseStatistics::CreateEmpty(VariantShredding::GetUnshreddedType())); + return result; +} + +//===--------------------------------------------------------------------===// +// Unshredded Stats +//===--------------------------------------------------------------------===// + +void VariantStats::CreateUnshreddedStats(BaseStatistics &stats) { + BaseStatistics::Construct(stats.child_stats[0], VariantShredding::GetUnshreddedType()); +} + +const BaseStatistics &VariantStats::GetUnshreddedStats(const BaseStatistics &stats) { + AssertVariant(stats); + return stats.child_stats[0]; +} + +BaseStatistics &VariantStats::GetUnshreddedStats(BaseStatistics &stats) { + AssertVariant(stats); + return stats.child_stats[0]; +} + +void VariantStats::SetUnshreddedStats(BaseStatistics &stats, const BaseStatistics &new_stats) { + AssertVariant(stats); + stats.child_stats[0].Copy(new_stats); +} + +void VariantStats::SetUnshreddedStats(BaseStatistics &stats, unique_ptr new_stats) { + AssertVariant(stats); + if (!new_stats) { + CreateUnshreddedStats(stats); + } else { + SetUnshreddedStats(stats, *new_stats); + } +} + +void VariantStats::MarkAsNotShredded(BaseStatistics &stats) { + D_ASSERT(!IsShredded(stats)); + auto &data = GetDataUnsafe(stats); + //! All Variant stats start off as UNINITIALIZED, to support merging + //! 
This method marks the stats as being unshredded, so they produce INCONSISTENT when merged with SHREDDED stats + data.shredding_state = VariantStatsShreddingState::NOT_SHREDDED; +} + +//===--------------------------------------------------------------------===// +// Shredded Stats +//===--------------------------------------------------------------------===// + +static void AssertShreddedStats(const BaseStatistics &stats) { + if (stats.GetType().id() != LogicalTypeId::STRUCT) { + throw InternalException("Shredded stats should be of type STRUCT, not %s", + EnumUtil::ToString(stats.GetType().id())); + } + auto &struct_children = StructType::GetChildTypes(stats.GetType()); + if (struct_children.size() != 2) { + throw InternalException( + "Shredded stats need to consist of 2 children, 'untyped_value_index' and 'typed_value', not: %s", + stats.GetType().ToString()); + } + if (struct_children[0].second.id() != LogicalTypeId::UINTEGER) { + throw InternalException("Shredded stats 'untyped_value_index' should be of type UINTEGER, not %s", + EnumUtil::ToString(struct_children[0].second.id())); + } +} + +bool VariantShreddedStats::IsFullyShredded(const BaseStatistics &stats) { + AssertShreddedStats(stats); + + auto &untyped_value_index_stats = StructStats::GetChildStats(stats, 0); + auto &typed_value_stats = StructStats::GetChildStats(stats, 1); + + if (!typed_value_stats.CanHaveNull()) { + //! Fully shredded, no nulls + return true; + } + if (!untyped_value_index_stats.CanHaveNoNull()) { + //! In the event that this field is entirely missing from the parent OBJECT, both are NULL + return false; + } + if (!NumericStats::HasMin(untyped_value_index_stats) || !NumericStats::HasMax(untyped_value_index_stats)) { + //! Has no min/max values, essentially double-checking the CanHaveNoNull from above + return false; + } + auto min_value = NumericStats::GetMinUnsafe(untyped_value_index_stats); + auto max_value = NumericStats::GetMaxUnsafe(untyped_value_index_stats); + if (min_value != max_value) { + //! Not a constant + return false; + } + //! 
0 is reserved for NULL Variant values + return min_value == 0; +} + +LogicalType ToStructuredType(const LogicalType &shredding) { + D_ASSERT(shredding.id() == LogicalTypeId::STRUCT); + auto &child_types = StructType::GetChildTypes(shredding); + D_ASSERT(child_types.size() == 2); + + auto &typed_value = child_types[1].second; + + if (typed_value.id() == LogicalTypeId::STRUCT) { + auto &struct_children = StructType::GetChildTypes(typed_value); + child_list_t structured_children; + for (auto &child : struct_children) { + structured_children.emplace_back(child.first, ToStructuredType(child.second)); + } + return LogicalType::STRUCT(structured_children); + } else if (typed_value.id() == LogicalTypeId::LIST) { + auto &child_type = ListType::GetChildType(typed_value); + return LogicalType::LIST(ToStructuredType(child_type)); + } else { + return typed_value; + } +} + +LogicalType VariantStats::GetShreddedStructuredType(const BaseStatistics &stats) { + D_ASSERT(IsShredded(stats)); + return ToStructuredType(GetShreddedStats(stats).GetType()); +} + +void VariantStats::CreateShreddedStats(BaseStatistics &stats, const LogicalType &shredded_type) { + BaseStatistics::Construct(stats.child_stats[1], shredded_type); + auto &data = GetDataUnsafe(stats); + data.shredding_state = VariantStatsShreddingState::SHREDDED; +} + +bool VariantStats::IsShredded(const BaseStatistics &stats) { + auto &data = GetDataUnsafe(stats); + return data.shredding_state == VariantStatsShreddingState::SHREDDED; +} + +BaseStatistics VariantStats::CreateShredded(const LogicalType &shredded_type) { + BaseStatistics result(LogicalType::VARIANT()); + result.InitializeEmpty(); + + CreateShreddedStats(result, shredded_type); + result.child_stats[0].Copy(BaseStatistics::CreateEmpty(VariantShredding::GetUnshreddedType())); + result.child_stats[1].Copy(BaseStatistics::CreateEmpty(shredded_type)); + return result; +} + +const BaseStatistics &VariantStats::GetShreddedStats(const BaseStatistics &stats) { + AssertVariant(stats); + D_ASSERT(IsShredded(stats)); + return stats.child_stats[1]; +} + +BaseStatistics &VariantStats::GetShreddedStats(BaseStatistics &stats) { + AssertVariant(stats); + D_ASSERT(IsShredded(stats)); + return stats.child_stats[1]; +} + +void VariantStats::SetShreddedStats(BaseStatistics &stats, const BaseStatistics &new_stats) { + auto &data = GetDataUnsafe(stats); + if (!IsShredded(stats)) { + BaseStatistics::Construct(stats.child_stats[1], new_stats.GetType()); + D_ASSERT(data.shredding_state != VariantStatsShreddingState::INCONSISTENT); + data.shredding_state = VariantStatsShreddingState::SHREDDED; + } + stats.child_stats[1].Copy(new_stats); +} + +void VariantStats::SetShreddedStats(BaseStatistics &stats, unique_ptr new_stats) { + AssertVariant(stats); + D_ASSERT(new_stats); + SetShreddedStats(stats, *new_stats); +} + +//===--------------------------------------------------------------------===// +// (De)Serialization +//===--------------------------------------------------------------------===// + +void VariantStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + auto &data = GetDataUnsafe(stats); + auto &unshredded_stats = VariantStats::GetUnshreddedStats(stats); + + serializer.WriteProperty(200, "shredding_state", data.shredding_state); + + serializer.WriteProperty(225, "unshredded_stats", unshredded_stats); + if (IsShredded(stats)) { + auto &shredded_stats = VariantStats::GetShreddedStats(stats); + serializer.WriteProperty(230, "shredded_type", shredded_stats.type); + serializer.WriteProperty(235, 
"shredded_stats", shredded_stats); + } +} + +void VariantStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &type = base.GetType(); + D_ASSERT(type.InternalType() == PhysicalType::STRUCT); + D_ASSERT(type.id() == LogicalTypeId::VARIANT); + auto &data = GetDataUnsafe(base); + + auto unshredded_type = VariantShredding::GetUnshreddedType(); + data.shredding_state = deserializer.ReadProperty(200, "shredding_state"); + + { + //! Read the 'unshredded_stats' child + deserializer.Set(unshredded_type); + auto stat = deserializer.ReadProperty(225, "unshredded_stats"); + base.child_stats[0].Copy(stat); + deserializer.Unset(); + } + + if (!IsShredded(base)) { + return; + } + //! Read the type of the 'shredded_stats' + auto shredded_type = deserializer.ReadProperty(230, "shredded_type"); + + { + //! Finally read the 'shredded_stats' themselves + deserializer.Set(shredded_type); + auto stat = deserializer.ReadProperty(235, "shredded_stats"); + if (base.child_stats[1].type.id() == LogicalTypeId::INVALID) { + base.child_stats[1] = BaseStatistics::CreateUnknown(shredded_type); + } + base.child_stats[1].Copy(stat); + deserializer.Unset(); + } +} + +static string ToStringInternal(const BaseStatistics &stats) { + string result; + result = StringUtil::Format("fully_shredded: %s", VariantShreddedStats::IsFullyShredded(stats) ? "true" : "false"); + + auto &typed_value = StructStats::GetChildStats(stats, 1); + auto type_id = typed_value.GetType().id(); + if (type_id == LogicalTypeId::LIST) { + result += ", child: "; + auto &child_stats = ListStats::GetChildStats(typed_value); + result += ToStringInternal(child_stats); + } else if (type_id == LogicalTypeId::STRUCT) { + result += ", children: {"; + auto &fields = StructType::GetChildTypes(typed_value.GetType()); + for (idx_t i = 0; i < fields.size(); i++) { + if (i) { + result += ", "; + } + auto &child_stats = StructStats::GetChildStats(typed_value, i); + result += StringUtil::Format("%s: %s", fields[i].first, ToStringInternal(child_stats)); + } + result += "}"; + } + return result; +} + +string VariantStats::ToString(const BaseStatistics &stats) { + string result; + bool is_shredded = IsShredded(stats); + auto &data = GetDataUnsafe(stats); + result = StringUtil::Format("shredding_state: %s", EnumUtil::ToString(data.shredding_state)); + if (is_shredded) { + result += ", shredding: {"; + result += StringUtil::Format("typed_value_type: %s, ", ToStructuredType(stats.child_stats[1].type).ToString()); + result += StringUtil::Format("stats: {%s}", ToStringInternal(stats.child_stats[1])); + result += "}"; + } + return result; +} + +static BaseStatistics WrapTypedValue(BaseStatistics &untyped_value_index, BaseStatistics &typed_value) { + BaseStatistics shredded = BaseStatistics::CreateEmpty(LogicalType::STRUCT( + {{"untyped_value_index", untyped_value_index.GetType()}, {"typed_value", typed_value.GetType()}})); + + StructStats::GetChildStats(shredded, 0).Copy(untyped_value_index); + StructStats::GetChildStats(shredded, 1).Copy(typed_value); + return shredded; +} + +bool VariantStats::MergeShredding(BaseStatistics &stats, const BaseStatistics &other, BaseStatistics &new_stats) { + //! shredded_type: + //! STRUCT(untyped_value_index UINTEGER, typed_value ) + + //! shredding, 1 of: + //! - + //! - + //! 
- [] + + D_ASSERT(stats.type.id() == LogicalTypeId::STRUCT); + D_ASSERT(other.type.id() == LogicalTypeId::STRUCT); + + auto &stats_children = StructType::GetChildTypes(stats.type); + auto &other_children = StructType::GetChildTypes(other.type); + D_ASSERT(stats_children.size() == 2); + D_ASSERT(other_children.size() == 2); + + auto &stats_typed_value_type = stats_children[1].second; + auto &other_typed_value_type = other_children[1].second; + + //! Merge the untyped_value_index stats + auto &untyped_value_index = StructStats::GetChildStats(stats, 0); + untyped_value_index.Merge(StructStats::GetChildStats(other, 0)); + + auto &stats_typed_value = StructStats::GetChildStats(stats, 1); + auto &other_typed_value = StructStats::GetChildStats(other, 1); + + if (stats_typed_value_type.id() == LogicalTypeId::STRUCT) { + if (stats_typed_value_type.id() != other_typed_value_type.id()) { + //! other is not an OBJECT, can't merge + return false; + } + auto &stats_object_children = StructType::GetChildTypes(stats_typed_value_type); + auto &other_object_children = StructType::GetChildTypes(other_typed_value_type); + + //! Map field name to index, for 'other' + case_insensitive_map_t key_to_index; + for (idx_t i = 0; i < other_object_children.size(); i++) { + auto &other_object_child = other_object_children[i]; + key_to_index.emplace(other_object_child.first, i); + } + + //! Attempt to merge all overlapping fields, only keep the fields that were able to be merged + child_list_t new_children; + vector new_child_stats; + + for (idx_t i = 0; i < stats_object_children.size(); i++) { + auto &stats_object_child = stats_object_children[i]; + auto other_it = key_to_index.find(stats_object_child.first); + if (other_it == key_to_index.end()) { + continue; + } + auto &other_object_child = other_object_children[other_it->second]; + if (other_object_child.second.id() != stats_object_child.second.id()) { + //! TODO: perhaps we can keep the field but demote the type to unshredded somehow? + //! Or even use MaxLogicalType and merge the stats into that ? + continue; + } + + auto &stats_child = StructStats::GetChildStats(stats_typed_value, i); + auto &other_child = StructStats::GetChildStats(other_typed_value, other_it->second); + BaseStatistics new_child; + if (!MergeShredding(stats_child, other_child, new_child)) { + continue; + } + new_children.emplace_back(stats_object_child.first, new_child.GetType()); + new_child_stats.emplace_back(std::move(new_child)); + } + if (new_children.empty()) { + //! No fields remaining, demote to unshredded + return false; + } + + //! Create new stats out of the remaining fields + auto new_object_type = LogicalType::STRUCT(std::move(new_children)); + auto new_typed_value = BaseStatistics::CreateEmpty(new_object_type); + for (idx_t i = 0; i < new_child_stats.size(); i++) { + StructStats::SetChildStats(new_typed_value, i, new_child_stats[i]); + } + new_typed_value.CombineValidity(stats_typed_value, other_typed_value); + new_stats = WrapTypedValue(untyped_value_index, new_typed_value); + return true; + } else if (stats_typed_value_type.id() == LogicalTypeId::LIST) { + if (stats_typed_value_type.id() != other_typed_value_type.id()) { + //! other is not an ARRAY, can't merge + return false; + } + auto &stats_child = ListStats::GetChildStats(stats_typed_value); + auto &other_child = ListStats::GetChildStats(other_typed_value); + + //! TODO: perhaps we can keep the LIST part of the stats, and only demote the child to unshredded? 
+ BaseStatistics new_child_stats; + if (!MergeShredding(stats_child, other_child, new_child_stats)) { + return false; + } + auto new_typed_value = BaseStatistics::CreateEmpty(LogicalType::LIST(new_child_stats.type)); + new_typed_value.CombineValidity(stats_typed_value, other_typed_value); + ListStats::SetChildStats(new_typed_value, new_child_stats.ToUnique()); + new_stats = WrapTypedValue(untyped_value_index, new_typed_value); + return true; + } else { + D_ASSERT(!stats_typed_value_type.IsNested()); + if (stats_typed_value_type.id() != other_typed_value_type.id()) { + //! other is not the same type, can't merge + return false; + } + stats_typed_value.Merge(other_typed_value); + new_stats = std::move(stats); + return true; + } +} + +void VariantStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + + stats.child_stats[0].Merge(other.child_stats[0]); + auto &data = GetDataUnsafe(stats); + auto &other_data = GetDataUnsafe(other); + + const auto other_shredding_state = other_data.shredding_state; + const auto shredding_state = data.shredding_state; + + if (other_shredding_state == VariantStatsShreddingState::UNINITIALIZED) { + //! No need to merge + return; + } + + switch (shredding_state) { + case VariantStatsShreddingState::INCONSISTENT: { + //! INCONSISTENT + ANY -> INCONSISTENT + return; + } + case VariantStatsShreddingState::UNINITIALIZED: { + switch (other_shredding_state) { + case VariantStatsShreddingState::SHREDDED: + stats.child_stats[1] = BaseStatistics::CreateUnknown(other.child_stats[1].GetType()); + stats.child_stats[1].Copy(other.child_stats[1]); + break; + default: + break; + } + //! UNINITIALIZED + ANY -> ANY + data.shredding_state = other_shredding_state; + break; + } + case VariantStatsShreddingState::NOT_SHREDDED: { + if (other_shredding_state == VariantStatsShreddingState::NOT_SHREDDED) { + return; + } + //! NOT_SHREDDED + !NOT_SHREDDED -> INCONSISTENT + data.shredding_state = VariantStatsShreddingState::INCONSISTENT; + stats.child_stats[1].type = LogicalType::INVALID; + break; + } + case VariantStatsShreddingState::SHREDDED: { + switch (other_shredding_state) { + case VariantStatsShreddingState::SHREDDED: { + BaseStatistics merged_shredding_stats; + if (!MergeShredding(stats.child_stats[1], other.child_stats[1], merged_shredding_stats)) { + //! SHREDDED(T1) + SHREDDED(T2) -> INCONSISTENT + data.shredding_state = VariantStatsShreddingState::INCONSISTENT; + stats.child_stats[1].type = LogicalType::INVALID; + } else { + //! SHREDDED(T1) + SHREDDED(T1) -> SHREDDED + stats.child_stats[1] = BaseStatistics::CreateUnknown(merged_shredding_stats.GetType()); + stats.child_stats[1].Copy(merged_shredding_stats); + } + break; + } + default: + //! SHREDDED + !SHREDDED -> INCONSISTENT + data.shredding_state = VariantStatsShreddingState::INCONSISTENT; + stats.child_stats[1].type = LogicalType::INVALID; + break; + } + break; + } + } +} + +void VariantStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { + auto &other_data = VariantStats::GetDataUnsafe(other); + auto &data = VariantStats::GetDataUnsafe(stats); + (void)data; + + //! 
This is ensured by the CopyBase method of BaseStatistics + D_ASSERT(data.shredding_state == other_data.shredding_state); + stats.child_stats[0].Copy(other.child_stats[0]); + if (IsShredded(other)) { + stats.child_stats[1] = BaseStatistics::CreateUnknown(other.child_stats[1].GetType()); + stats.child_stats[1].Copy(other.child_stats[1]); + } else { + stats.child_stats[1].type = LogicalType::INVALID; + } +} + +void VariantStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + // TODO: Verify stats +} + +const VariantStatsData &VariantStats::GetDataUnsafe(const BaseStatistics &stats) { + AssertVariant(stats); + return stats.stats_union.variant_data; +} + +VariantStatsData &VariantStats::GetDataUnsafe(BaseStatistics &stats) { + AssertVariant(stats); + return stats.stats_union.variant_data; +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/storage_info.cpp b/src/duckdb/src/storage/storage_info.cpp index 616aa3039..c6ba18fd9 100644 --- a/src/duckdb/src/storage/storage_info.cpp +++ b/src/duckdb/src/storage/storage_info.cpp @@ -4,6 +4,10 @@ #include "duckdb/common/optional_idx.hpp" namespace duckdb { +constexpr idx_t Storage::MAX_ROW_GROUP_SIZE; +constexpr idx_t Storage::MAX_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::MIN_BLOCK_ALLOC_SIZE; +constexpr idx_t Storage::DEFAULT_BLOCK_HEADER_SIZE; const uint64_t VERSION_NUMBER = 64; const uint64_t VERSION_NUMBER_LOWER = 64; @@ -83,6 +87,9 @@ static const StorageVersionInfo storage_version_info[] = { {"v1.3.1", 66}, {"v1.3.2", 66}, {"v1.4.0", 67}, + {"v1.4.1", 67}, + {"v1.4.2", 67}, + {"v1.4.3", 67}, {"v1.5.0", 67}, {nullptr, 0} }; @@ -109,6 +116,9 @@ static const SerializationVersionInfo serialization_version_info[] = { {"v1.3.1", 5}, {"v1.3.2", 5}, {"v1.4.0", 6}, + {"v1.4.1", 6}, + {"v1.4.2", 6}, + {"v1.4.3", 6}, {"v1.5.0", 7}, {"latest", 7}, {nullptr, 0} diff --git a/src/duckdb/src/storage/storage_manager.cpp b/src/duckdb/src/storage/storage_manager.cpp index d82905452..b4e9b521a 100644 --- a/src/duckdb/src/storage/storage_manager.cpp +++ b/src/duckdb/src/storage/storage_manager.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/checkpoint_manager.hpp" #include "duckdb/storage/in_memory_block_manager.hpp" @@ -16,6 +17,7 @@ #include "duckdb/catalog/duck_catalog.hpp" #include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" #include "mbedtls_wrapper.hpp" namespace duckdb { @@ -78,9 +80,7 @@ void StorageOptions::Initialize(const unordered_map &options) { } StorageManager::StorageManager(AttachedDatabase &db, string path_p, const AttachOptions &options) - : db(db), path(std::move(path_p)), read_only(options.access_mode == AccessMode::READ_ONLY), - in_memory_change_size(0) { - + : db(db), path(std::move(path_p)), read_only(options.access_mode == AccessMode::READ_ONLY), wal_size(0) { if (path.empty()) { path = IN_MEMORY_PATH; return; @@ -110,7 +110,15 @@ ObjectCache &ObjectCache::GetObjectCache(ClientContext &context) { } idx_t StorageManager::GetWALSize() { - return InMemory() ? 
in_memory_change_size.load() : wal->GetWALSize(); + return wal_size; +} + +void StorageManager::AddWALSize(idx_t size) { + wal_size += size; +} + +void StorageManager::SetWALSize(idx_t size) { + wal_size = size; } optional_ptr StorageManager::GetWAL() { @@ -120,24 +128,115 @@ optional_ptr StorageManager::GetWAL() { return wal.get(); } -void StorageManager::ResetWAL() { - wal->Delete(); +bool StorageManager::HasWAL() const { + if (InMemory() || read_only || !load_complete) { + return false; + } + return true; +} + +bool StorageManager::WALStartCheckpoint(MetaBlockPointer meta_block, CheckpointOptions &options) { + lock_guard guard(wal_lock); + // while holding the WAL lock - get the last committed transaction from the transaction manager + // this is the commit we will be checkpointing on - everything in this commit will be written to the file + // any new commits made will be written to the next wal + auto &transaction_manager = db.GetTransactionManager().Cast(); + options.transaction_id = transaction_manager.GetNewCheckpointId(); + + DUCKDB_LOG(db.GetDatabase(), TransactionLogType, db, "Start Checkpoint", options.transaction_id); + if (!wal) { + return false; + } + // start the checkpoint process around the WAL + if (GetWALSize() == 0) { + // no WAL - we don't need to do anything here + return false; + } + // verify the main WAL is the active WAL currently + if (wal->GetPath() != wal_path) { + throw InternalException("Current WAL path %s does not match base WAL path %s in WALStartCheckpoint", + wal->GetPath(), wal_path); + } + // write to the main WAL that we have initiated a checkpoint + wal->WriteCheckpoint(meta_block); + wal->Flush(); + + // close the main WAL + wal.reset(); + + // replace the WAL with a new WAL (.checkpoint.wal) that transactions can write to while the checkpoint is happening + // we don't eagerly write to this WAL - we just instantiate it here so it can be written to + // if a checkpoint WAL already exists - delete it before proceeding + auto checkpoint_wal_path = GetCheckpointWALPath(); + auto &fs = FileSystem::Get(db); + fs.TryRemoveFile(checkpoint_wal_path); + + // the checkpoint WAL belongs to the NEXT checkpoint - when we are done we will overwrite the current WAL with it + // as such override the checkpoint iteration number to the next one + auto &single_file_block_manager = GetBlockManager().Cast(); + auto next_checkpoint_iteration = single_file_block_manager.GetCheckpointIteration() + 1; + wal = make_uniq(*this, checkpoint_wal_path, 0ULL, WALInitState::NO_WAL, next_checkpoint_iteration); + return true; +} + +void StorageManager::WALFinishCheckpoint() { + lock_guard guard(wal_lock); + D_ASSERT(wal.get()); + + // "wal" points to the checkpoint WAL + // first check if the checkpoint WAL has been written to + auto &fs = FileSystem::Get(db); + if (!wal->Initialized()) { + // the checkpoint WAL has not been written to + // this is the common scenario if there are no concurrent writes happening while checkpointing + // in this case we can just remove the main WAL and re-instantiate it to empty + fs.TryRemoveFile(wal_path); + + wal = make_uniq(*this, wal_path); + return; + } + + // we have had writes to the checkpoint WAL - we need to override our WAL with the checkpoint WAL + // first close the WAL writer + auto checkpoint_wal_path = wal->GetPath(); + wal.reset(); + + // move the secondary WAL over the main WAL + fs.MoveFile(checkpoint_wal_path, wal_path); + + // open what is now the main WAL again + wal = make_uniq(*this, wal_path); + wal->Initialize(); + + 
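// Editorial note, not part of the patch: WALStartCheckpoint/WALFinishCheckpoint above rotate the
// WAL around a checkpoint: commits that arrive while the checkpoint runs land in
// "<db>.checkpoint.wal", and afterwards that side file is either discarded (it was never written
// to) or renamed over the main "<db>.wal". A standalone sketch of the finish step using
// std::filesystem; the function name and the was_written flag are illustrative, not DuckDB's API:
#include <filesystem>
#include <system_error>

void FinishWalRotation(const std::filesystem::path &main_wal, const std::filesystem::path &checkpoint_wal,
                       bool checkpoint_wal_was_written) {
	std::error_code ec;
	if (!checkpoint_wal_was_written) {
		// common case: no concurrent writes during the checkpoint, start over with an empty main WAL
		std::filesystem::remove(main_wal, ec);
		return;
	}
	// concurrent writes happened: promote the checkpoint WAL to be the new main WAL
	std::filesystem::rename(checkpoint_wal, main_wal);
}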
DUCKDB_LOG(db.GetDatabase(), TransactionLogType, db, "Finish Checkpoint"); +} + +unique_ptr> StorageManager::GetWALLock() { + return make_uniq>(wal_lock); } -string StorageManager::GetWALPath() const { +string StorageManager::GetWALPath(const string &suffix) { // we append the ".wal" **before** a question mark in case of GET parameters // but only if we are not in a windows long path (which starts with \\?\) std::size_t question_mark_pos = std::string::npos; if (!StringUtil::StartsWith(path, "\\\\?\\")) { question_mark_pos = path.find('?'); } - auto wal_path = path; + auto result = path; if (question_mark_pos != std::string::npos) { - wal_path.insert(question_mark_pos, ".wal"); + result.insert(question_mark_pos, suffix); } else { - wal_path += ".wal"; + result += suffix; } - return wal_path; + return result; +} + +string StorageManager::GetCheckpointWALPath() { + return GetWALPath(".checkpoint.wal"); +} + +string StorageManager::GetRecoveryWALPath() { + return GetWALPath(".recovery.wal"); } bool StorageManager::InMemory() const { @@ -148,6 +247,14 @@ bool StorageManager::InMemory() const { void StorageManager::Destroy() { } +inline void ClearUserKey(shared_ptr const &encryption_key) { + if (encryption_key && !encryption_key->empty()) { + duckdb_mbedtls::MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(data_ptr_cast(&(*encryption_key)[0]), + encryption_key->size()); + encryption_key->clear(); + } +} + void StorageManager::Initialize(QueryContext context) { bool in_memory = InMemory(); if (in_memory && read_only) { @@ -239,7 +346,7 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { // file does not exist and we are in read-write mode // create a new file - auto wal_path = GetWALPath(); + wal_path = GetWALPath(); // try to remove the WAL file if it exists fs.TryRemoveFile(wal_path); @@ -274,7 +381,8 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { sf_block_manager->CreateNewDatabase(context); block_manager = std::move(sf_block_manager); table_io_manager = make_uniq(*block_manager, row_group_size); - wal = make_uniq(db, wal_path); + wal = make_uniq(*this, wal_path); + } else { // Either the file exists, or we are in read-only mode, so we // try to read the existing file on disk. @@ -322,13 +430,41 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { } } - // load the db from storage + unique_ptr timer = nullptr; + + // Start timing the storage load step. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler; + timer = make_uniq(profiler->StartTimer(MetricType::ATTACH_LOAD_STORAGE_LATENCY)); + } + + // Load the checkpoint from storage. auto checkpoint_reader = SingleFileCheckpointReader(*this); checkpoint_reader.LoadFromStorage(); - auto wal_path = GetWALPath(); - wal = WriteAheadLog::Replay(fs, db, wal_path); + // End timing the storage load step. + if (timer) { + timer->EndTimer(); + timer = nullptr; + } + + // Start timing the WAL replay step. + if (client_context) { + auto profiler = client_context->client_data->profiler; + timer = make_uniq(profiler->StartTimer(MetricType::ATTACH_REPLAY_WAL_LATENCY)); + } + + // Replay the WAL. + wal_path = GetWALPath(); + wal = WriteAheadLog::Replay(context, *this, wal_path); + + // End timing the WAL replay step. 
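// Editorial note, not part of the patch: LoadDatabase above now measures the two attach phases
// (loading the checkpoint, replaying the WAL) separately through the client profiler. The same
// start/stop pattern as a generic RAII stopwatch with std::chrono; PhaseTimer and the phase
// names are illustrative, not DuckDB's MetricType/profiler interface:
#include <chrono>
#include <cstdio>
#include <string>

class PhaseTimer {
public:
	explicit PhaseTimer(std::string name_p) : name(std::move(name_p)), begin(std::chrono::steady_clock::now()) {
	}
	~PhaseTimer() {
		auto elapsed = std::chrono::steady_clock::now() - begin;
		auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
		std::printf("%s took %lld ms\n", name.c_str(), static_cast<long long>(ms));
	}

private:
	std::string name;
	std::chrono::steady_clock::time_point begin;
};

void AttachSketch() {
	{
		PhaseTimer timer("attach_load_storage");   // destroyed at end of scope -> duration reported
		// ... load the checkpoint from storage ...
	}
	{
		PhaseTimer timer("attach_replay_wal");
		// ... replay the write-ahead log ...
	}
}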
+ if (timer) { + timer->EndTimer(); + } } + if (row_group_size > 122880ULL && GetStorageVersion() < 4) { throw InvalidInputException("Unsupported row group size %llu - row group sizes >= 122_880 are only supported " "with STORAGE_VERSION '1.2.0' or above.\nExplicitly specify a newer storage " @@ -388,8 +524,15 @@ SingleFileStorageCommitState::~SingleFileStorageCommitState() { return; } try { - // truncate the WAL in case of a destructor + // Truncate the WAL in case of a destructor. RevertCommit(); + } catch (std::exception &ex) { + ErrorData data(ex); + try { + DUCKDB_LOG_ERROR(wal.GetDatabase().GetDatabase(), + "SingleFileStorageCommitState::~SingleFileStorageCommitState()\t\t" + data.Message()); + } catch (...) { // NOLINT + } } catch (...) { // NOLINT } } @@ -464,9 +607,9 @@ bool SingleFileStorageManager::IsCheckpointClean(MetaBlockPointer checkpoint_id) unique_ptr SingleFileStorageManager::CreateCheckpointWriter(QueryContext context, CheckpointOptions options) { if (InMemory()) { - return make_uniq(context, db, *block_manager, *this, options.type); + return make_uniq(context, db, *block_manager, *this, options); } - return make_uniq(context, db, *block_manager, options.type); + return make_uniq(context, db, *block_manager, options); } void SingleFileStorageManager::CreateCheckpoint(QueryContext context, CheckpointOptions options) { @@ -476,20 +619,27 @@ void SingleFileStorageManager::CreateCheckpoint(QueryContext context, Checkpoint if (db.GetStorageExtension()) { db.GetStorageExtension()->OnCheckpointStart(db, options); } + auto &config = DBConfig::Get(db); - if (GetWALSize() > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { - // we only need to checkpoint if there is anything in the WAL + // We only need to checkpoint if there is anything in the WAL. + auto wal_size = GetWALSize(); + if (wal_size > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { try { + // Start timing the checkpoint. + auto client_context = context.GetClientContext(); + if (client_context) { + auto profiler = client_context->client_data->profiler->StartTimer(MetricType::CHECKPOINT_LATENCY); + } + + // Write the checkpoint. 
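// Editorial note, not part of the patch: CreateCheckpoint above only does real work when there is
// something to gain - the WAL carries data, the force_checkpoint setting is on, or the caller asked
// for an unconditional checkpoint. That gate reduces to a small predicate; the enum below is a
// simplified stand-in for the CheckpointAction used in the diff:
#include <cstdint>

enum class CheckpointRequest { IF_NEEDED, ALWAYS };

bool ShouldCheckpoint(uint64_t wal_size, bool force_checkpoint_setting, CheckpointRequest request) {
	return wal_size > 0 || force_checkpoint_setting || request == CheckpointRequest::ALWAYS;
}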
auto checkpointer = CreateCheckpointWriter(context, options); checkpointer->CreateCheckpoint(); + } catch (std::exception &ex) { ErrorData error(ex); - throw FatalException("Failed to create checkpoint because of error: %s", error.RawMessage()); + throw FatalException("Failed to create checkpoint because of error: %s", error.Message()); } } - if (!InMemory() && options.wal_action == CheckpointWALAction::DELETE_WAL) { - ResetWAL(); - } if (db.GetStorageExtension()) { db.GetStorageExtension()->OnCheckpointEnd(db, options); diff --git a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index 7c8a12f13..9042934cc 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -8,20 +8,22 @@ namespace duckdb { -ArrayColumnData::ArrayColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type_p, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), - validity(block_manager, info, 0, start_row, *this) { +ArrayColumnData::ArrayColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + LogicalType type_p, ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, std::move(type_p), data_type, parent) { D_ASSERT(type.InternalType() == PhysicalType::ARRAY); - auto &child_type = ArrayType::GetChildType(type); - // the child column, with column index 1 (0 is the validity mask) - child_column = ColumnData::CreateColumnUnique(block_manager, info, 1, start_row, child_type, this); + if (data_type != ColumnDataType::CHECKPOINT_TARGET) { + auto &child_type = ArrayType::GetChildType(type); + validity = make_shared_ptr(block_manager, info, 0, *this); + // the child column, with column index 1 (0 is the validity mask) + child_column = CreateColumn(block_manager, info, 1, child_type, data_type, this); + } } -void ArrayColumnData::SetStart(idx_t new_start) { - this->start = new_start; - child_column->SetStart(new_start); - validity.SetStart(new_start); +void ArrayColumnData::SetDataType(ColumnDataType data_type) { + ColumnData::SetDataType(data_type); + child_column->SetDataType(data_type); + validity->SetDataType(data_type); } FilterPropagateResult ArrayColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { @@ -32,7 +34,7 @@ FilterPropagateResult ArrayColumnData::CheckZonemap(ColumnScanState &state, Tabl void ArrayColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { ColumnData::InitializePrefetch(prefetch_state, scan_state, rows); - validity.InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); + validity->InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); auto array_size = ArrayType::GetSize(type); child_column->InitializePrefetch(prefetch_state, scan_state.child_states[1], rows * array_size); } @@ -41,10 +43,10 @@ void ArrayColumnData::InitializeScan(ColumnScanState &state) { // initialize the validity segment D_ASSERT(state.child_states.size() == 2); - state.row_index = 0; + state.offset_in_column = 0; state.current = nullptr; - validity.InitializeScan(state.child_states[0]); + validity->InitializeScan(state.child_states[0]); // initialize the child scan child_column->InitializeScan(state.child_states[1]); @@ -59,18 +61,18 @@ void ArrayColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row return; } - 
state.row_index = row_idx; + state.offset_in_column = row_idx; state.current = nullptr; // initialize the validity segment - validity.InitializeScanWithOffset(state.child_states[0], row_idx); + validity->InitializeScanWithOffset(state.child_states[0], row_idx); auto array_size = ArrayType::GetSize(type); - auto child_count = (row_idx - start) * array_size; + auto child_count = row_idx * array_size; D_ASSERT(child_count <= child_column->GetMaxEntry()); if (child_count < child_column->GetMaxEntry()) { - const auto child_offset = start + child_count; + const auto child_offset = child_count; child_column->InitializeScanWithOffset(state.child_states[1], child_offset); } } @@ -87,7 +89,7 @@ idx_t ArrayColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, idx_t ArrayColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset) { // Scan validity - auto scan_count = validity.ScanCount(state.child_states[0], result, count, result_offset); + auto scan_count = validity->ScanCount(state.child_states[0], result, count, result_offset); auto array_size = ArrayType::GetSize(type); // Scan child column auto &child_vec = ArrayVector::GetEntry(result); @@ -120,7 +122,7 @@ void ArrayColumnData::Select(TransactionData transaction, idx_t vector_index, Co // not consecutive - break break; } - end_idx = next_idx; + end_idx = next_idx + 1; } consecutive_ranges++; } @@ -155,12 +157,12 @@ void ArrayColumnData::Select(TransactionData transaction, idx_t vector_index, Co if (start_idx > current_position) { // skip forward idx_t skip_amount = start_idx - current_position; - validity.Skip(state.child_states[0], skip_amount); + validity->Skip(state.child_states[0], skip_amount); child_column->Skip(state.child_states[1], skip_amount * array_size); } // scan into the result array idx_t scan_count = end_idx - start_idx; - validity.ScanCount(state.child_states[0], result, scan_count, current_offset); + validity->ScanCount(state.child_states[0], result, scan_count, current_offset); child_column->ScanCount(state.child_states[1], child_vec, scan_count * array_size, current_offset * array_size); // move the current position forward current_offset += scan_count; @@ -169,14 +171,14 @@ void ArrayColumnData::Select(TransactionData transaction, idx_t vector_index, Co // if there is any remaining at the end - skip any trailing rows if (current_position < target_count) { idx_t skip_amount = target_count - current_position; - validity.Skip(state.child_states[0], skip_amount); + validity->Skip(state.child_states[0], skip_amount); child_column->Skip(state.child_states[1], skip_amount * array_size); } } void ArrayColumnData::Skip(ColumnScanState &state, idx_t count) { // Skip validity - validity.Skip(state.child_states[0], count); + validity->Skip(state.child_states[0], count); // Skip child column auto array_size = ArrayType::GetSize(type); child_column->Skip(state.child_states[1], count * array_size); @@ -184,7 +186,7 @@ void ArrayColumnData::Skip(ColumnScanState &state, idx_t count) { void ArrayColumnData::InitializeAppend(ColumnAppendState &state) { ColumnAppendState validity_append; - validity.InitializeAppend(validity_append); + validity->InitializeAppend(validity_append); state.child_appends.push_back(std::move(validity_append)); ColumnAppendState child_append; @@ -201,7 +203,7 @@ void ArrayColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Ve } // Append validity - validity.Append(stats, state.child_appends[0], vector, count); + validity->Append(stats, 
state.child_appends[0], vector, count); // Append child column auto array_size = ArrayType::GetSize(type); auto &child_vec = ArrayVector::GetEntry(vector); @@ -210,27 +212,28 @@ void ArrayColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Ve this->count += count; } -void ArrayColumnData::RevertAppend(row_t start_row) { +void ArrayColumnData::RevertAppend(row_t new_count) { // Revert validity - validity.RevertAppend(start_row); + validity->RevertAppend(new_count); // Revert child column auto array_size = ArrayType::GetSize(type); - child_column->RevertAppend(start_row * UnsafeNumericCast(array_size)); + child_column->RevertAppend(new_count * UnsafeNumericCast(array_size)); - this->count = UnsafeNumericCast(start_row) - this->start; + this->count = UnsafeNumericCast(new_count); } idx_t ArrayColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { throw NotImplementedException("Array Fetch"); } -void ArrayColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ArrayColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { throw NotImplementedException("Array Update is not supported."); } -void ArrayColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ArrayColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { throw NotImplementedException("Array Update Column is not supported"); } @@ -240,14 +243,13 @@ unique_ptr ArrayColumnData::GetUpdateStatistics() { void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - // Create state for validity & child column if (state.child_states.empty()) { state.child_states.push_back(make_uniq()); } // Fetch validity - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); // Fetch child column auto &child_vec = ArrayVector::GetEntry(result); @@ -255,24 +257,41 @@ void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &st auto array_size = ArrayType::GetSize(type); // We need to fetch between [row_id * array_size, (row_id + 1) * array_size) - auto child_state = make_uniq(); - child_state->Initialize(child_type, nullptr); + ColumnScanState child_state(nullptr); + child_state.Initialize(state.context, child_type, nullptr); - const auto child_offset = start + (UnsafeNumericCast(row_id) - start) * array_size; + const auto child_offset = UnsafeNumericCast(row_id) * array_size; - child_column->InitializeScanWithOffset(*child_state, child_offset); + child_column->InitializeScanWithOffset(child_state, child_offset); Vector child_scan(child_type, array_size); - child_column->ScanCount(*child_state, child_scan, array_size); + child_column->ScanCount(child_state, child_scan, array_size); VectorOperations::Copy(child_scan, child_vec, array_size, 0, result_idx * array_size); } -void ArrayColumnData::CommitDropColumn() { - validity.CommitDropColumn(); - child_column->CommitDropColumn(); +void ArrayColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + 
validity->VisitBlockIds(visitor); + child_column->VisitBlockIds(visitor); +} + +void ArrayColumnData::SetValidityData(shared_ptr validity_p) { + if (validity) { + throw InternalException("ArrayColumnData::SetValidityData cannot be used to overwrite existing validity"); + } + validity_p->SetParent(this); + this->validity = std::move(validity_p); +} + +void ArrayColumnData::SetChildData(shared_ptr child_column_p) { + if (child_column) { + throw InternalException("ArrayColumnData::SetChildData cannot be used to overwrite existing data"); + } + child_column_p->SetParent(this); + this->child_column = std::move(child_column_p); } struct ArrayColumnCheckpointState : public ColumnCheckpointState { - ArrayColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) + ArrayColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) : ColumnCheckpointState(row_group, column_data, partial_block_manager) { global_stats = ArrayStats::CreateEmpty(column_data.type).ToUnique(); } @@ -281,69 +300,87 @@ struct ArrayColumnCheckpointState : public ColumnCheckpointState { unique_ptr child_state; public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, original_column.type, + ColumnDataType::CHECKPOINT_TARGET, nullptr); + } + + shared_ptr GetFinalResult() override { + if (!result_column) { + result_column = CreateEmptyColumnData(); + } + auto &column_data = result_column->Cast(); + auto validity_child = validity_state->GetFinalResult(); + column_data.SetValidityData(shared_ptr_cast(std::move(validity_child))); + column_data.SetChildData(child_state->GetFinalResult()); + return ColumnCheckpointState::GetFinalResult(); + } + unique_ptr GetStatistics() override { auto stats = global_stats->Copy(); + stats.Merge(*validity_state->GetStatistics()); ArrayStats::SetChildStats(stats, child_state->GetStatistics()); return stats.ToUnique(); } PersistentColumnData ToPersistentData() override { - PersistentColumnData data(PhysicalType::ARRAY); + PersistentColumnData data(original_column.type); data.child_columns.push_back(validity_state->ToPersistentData()); data.child_columns.push_back(child_state->ToPersistentData()); return data; } }; -unique_ptr ArrayColumnData::CreateCheckpointState(RowGroup &row_group, +unique_ptr ArrayColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { return make_uniq(row_group, *this, partial_block_manager); } -unique_ptr ArrayColumnData::Checkpoint(RowGroup &row_group, +unique_ptr ArrayColumnData::Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); - checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); + checkpoint_state->validity_state = validity->Checkpoint(row_group, checkpoint_info); checkpoint_state->child_state = child_column->Checkpoint(row_group, checkpoint_info); return std::move(checkpoint_state); } bool ArrayColumnData::IsPersistent() { - return validity.IsPersistent() && child_column->IsPersistent(); + return validity->IsPersistent() && child_column->IsPersistent(); } bool ArrayColumnData::HasAnyChanges() const 
{ - return child_column->HasAnyChanges() || validity.HasAnyChanges(); + return child_column->HasAnyChanges() || validity->HasAnyChanges(); } PersistentColumnData ArrayColumnData::Serialize() { - PersistentColumnData persistent_data(PhysicalType::ARRAY); - persistent_data.child_columns.push_back(validity.Serialize()); + PersistentColumnData persistent_data(type); + persistent_data.child_columns.push_back(validity->Serialize()); persistent_data.child_columns.push_back(child_column->Serialize()); return persistent_data; } void ArrayColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) { D_ASSERT(column_data.pointers.empty()); - validity.InitializeColumn(column_data.child_columns[0], target_stats); + validity->InitializeColumn(column_data.child_columns[0], target_stats); auto &child_stats = ArrayStats::GetChildStats(target_stats); child_column->InitializeColumn(column_data.child_columns[1], child_stats); - this->count = validity.count.load(); + this->count = validity->count.load(); } -void ArrayColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +void ArrayColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } void ArrayColumnData::Verify(RowGroup &parent) { #ifdef DEBUG ColumnData::Verify(parent); - validity.Verify(parent); + validity->Verify(parent); child_column->Verify(parent); #endif } diff --git a/src/duckdb/src/storage/table/chunk_info.cpp b/src/duckdb/src/storage/table/chunk_info.cpp index 3b7b11d7b..dfef0b4a1 100644 --- a/src/duckdb/src/storage/table/chunk_info.cpp +++ b/src/duckdb/src/storage/table/chunk_info.cpp @@ -1,10 +1,12 @@ #include "duckdb/storage/table/chunk_info.hpp" + #include "duckdb/transaction/transaction.hpp" #include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/transaction/delete_info.hpp" +#include "duckdb/execution/index/fixed_size_allocator.hpp" namespace duckdb { @@ -32,7 +34,7 @@ static bool UseVersion(TransactionData transaction, transaction_t id) { return TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, id); } -bool ChunkInfo::Cleanup(transaction_t lowest_transaction, unique_ptr &result) const { +bool ChunkInfo::Cleanup(transaction_t lowest_transaction) const { return false; } @@ -40,7 +42,7 @@ void ChunkInfo::Write(WriteStream &writer) const { writer.Write(type); } -unique_ptr ChunkInfo::Read(ReadStream &reader) { +unique_ptr ChunkInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto type = reader.Read(); switch (type) { case ChunkInfoType::EMPTY_INFO: @@ -48,7 +50,7 @@ unique_ptr ChunkInfo::Read(ReadStream &reader) { case ChunkInfoType::CONSTANT_INFO: return ChunkConstantInfo::Read(reader); case ChunkInfoType::VECTOR_INFO: - return ChunkVectorInfo::Read(reader); + return ChunkVectorInfo::Read(allocator, reader); default: throw SerializationException("Could not deserialize Chunk Info Type: unrecognized type"); } @@ -71,7 +73,7 @@ idx_t 
ChunkConstantInfo::TemplatedGetSelVector(transaction_t start_time, transac return 0; } -idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { +idx_t ChunkConstantInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return TemplatedGetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } @@ -81,6 +83,13 @@ idx_t ChunkConstantInfo::GetCommittedSelVector(transaction_t min_start_id, trans return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); } +idx_t ChunkConstantInfo::GetCheckpointRowCount(TransactionData transaction, idx_t max_count) { + if (TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, insert_id)) { + return max_count; + } + return 0; +} + bool ChunkConstantInfo::Fetch(TransactionData transaction, row_t row) { return UseVersion(transaction, insert_id) && !UseVersion(transaction, delete_id); } @@ -95,11 +104,11 @@ bool ChunkConstantInfo::HasDeletes() const { return is_deleted; } -idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) { +idx_t ChunkConstantInfo::GetCommittedDeletedCount(idx_t max_count) const { return delete_id < TRANSACTION_ID_START ? max_count : 0; } -bool ChunkConstantInfo::Cleanup(transaction_t lowest_transaction, unique_ptr &result) const { +bool ChunkConstantInfo::Cleanup(transaction_t lowest_transaction) const { if (delete_id != NOT_DELETED_ID) { // the chunk info is labeled as deleted - we need to keep it around return false; @@ -125,52 +134,85 @@ unique_ptr ChunkConstantInfo::Read(ReadStream &reader) { return std::move(info); } +string ChunkConstantInfo::ToString(idx_t max_count) const { + string result; + result += "Constant [Count: " + to_string(max_count); + result += ", "; + result += "Insert Id: " + to_string(insert_id); + if (delete_id != NOT_DELETED_ID) { + result += ", Delete Id: " + to_string(delete_id); + } + result += "]"; + return result; +} + //===--------------------------------------------------------------------===// // Vector info //===--------------------------------------------------------------------===// -ChunkVectorInfo::ChunkVectorInfo(idx_t start) - : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), insert_id(0), same_inserted_id(true), any_deleted(false) { - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - inserted[i] = 0; - deleted[i] = NOT_DELETED_ID; +ChunkVectorInfo::ChunkVectorInfo(FixedSizeAllocator &allocator_p, idx_t start, transaction_t insert_id_p) + : ChunkInfo(start, ChunkInfoType::VECTOR_INFO), allocator(allocator_p), constant_insert_id(insert_id_p) { +} + +ChunkVectorInfo::~ChunkVectorInfo() { + if (AnyDeleted()) { + allocator.Free(deleted_data); + } + if (!HasConstantInsertionId()) { + allocator.Free(inserted_data); } } template idx_t ChunkVectorInfo::TemplatedGetSelVector(transaction_t start_time, transaction_t transaction_id, SelectionVector &sel_vector, idx_t max_count) const { - idx_t count = 0; - if (same_inserted_id && !any_deleted) { - // all tuples have the same inserted id: and no tuples were deleted - if (OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { - return max_count; - } else { - return 0; + if (HasConstantInsertionId()) { + if (!AnyDeleted()) { + // all tuples have the same inserted id: and no tuples were deleted + if (OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { + return max_count; + } else { + return 0; + } } - } else if 
(same_inserted_id) { - if (!OP::UseInsertedVersion(start_time, transaction_id, insert_id)) { + if (!OP::UseInsertedVersion(start_time, transaction_id, ConstantInsertId())) { return 0; } // have to check deleted flag + idx_t count = 0; + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < max_count; i++) { if (OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { sel_vector.set_index(count++, i); } } - } else if (!any_deleted) { + return count; + } + if (!AnyDeleted()) { // have to check inserted flag + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + idx_t count = 0; for (idx_t i = 0; i < max_count; i++) { if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i])) { sel_vector.set_index(count++, i); } } - } else { - // have to check both flags - for (idx_t i = 0; i < max_count; i++) { - if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && - OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { - sel_vector.set_index(count++, i); - } + return count; + } + + idx_t count = 0; + // have to check both flags + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = delete_segment.GetPtr(); + for (idx_t i = 0; i < max_count; i++) { + if (OP::UseInsertedVersion(start_time, transaction_id, inserted[i]) && + OP::UseDeletedVersion(start_time, transaction_id, deleted[i])) { + sel_vector.set_index(count++, i); } } return count; @@ -186,16 +228,101 @@ idx_t ChunkVectorInfo::GetCommittedSelVector(transaction_t min_start_id, transac return TemplatedGetSelVector(min_start_id, min_transaction_id, sel_vector, max_count); } -idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) { +idx_t ChunkVectorInfo::GetSelVector(TransactionData transaction, SelectionVector &sel_vector, idx_t max_count) const { return GetSelVector(transaction.start_time, transaction.transaction_id, sel_vector, max_count); } +idx_t ChunkVectorInfo::GetCheckpointRowCount(TransactionData transaction, idx_t max_count) { + if (HasConstantInsertionId()) { + if (!TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, + ConstantInsertId())) { + return 0; + } + return max_count; + } + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + + idx_t count = 0; + for (idx_t i = 0; i < max_count; i++) { + if (!TransactionVersionOperator::UseInsertedVersion(transaction.start_time, transaction.transaction_id, + inserted[i])) { + continue; + } + if (i != count) { + throw InternalException("Error in ChunkVectorInfo::GetCheckpointRowCount - insertions are not sequential"); + } + count++; + } + return count; +} + bool ChunkVectorInfo::Fetch(TransactionData transaction, row_t row) { - return UseVersion(transaction, inserted[row]) && !UseVersion(transaction, deleted[row]); + transaction_t fetch_insert_id; + transaction_t fetch_deleted_id; + if (HasConstantInsertionId()) { + fetch_insert_id = ConstantInsertId(); + } else { + auto insert_segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = insert_segment.GetPtr(); + fetch_insert_id = inserted[row]; + } + if (!AnyDeleted()) { + fetch_deleted_id = NOT_DELETED_ID; + } else { + auto delete_segment = allocator.GetHandle(GetDeletedPointer()); 
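// Editorial note, not part of the patch: ChunkVectorInfo above now starts out with a single
// constant insert id and only materializes per-row id arrays (backed by the FixedSizeAllocator)
// once rows with different commit ids, or any deletions, show up in the same vector. A standalone
// sketch of that lazy materialization using a plain std::vector instead of allocator segments:
#include <cstddef>
#include <cstdint>
#include <vector>

class VersionVectorSketch {
public:
	VersionVectorSketch(std::size_t capacity_p, uint64_t first_insert_id)
	    : capacity(capacity_p), constant_insert_id(first_insert_id) {
	}

	void Append(std::size_t start, std::size_t end, uint64_t commit_id) {
		if (per_row_ids.empty() && commit_id == constant_insert_id) {
			return;                                   // still one id for the whole vector
		}
		if (per_row_ids.empty()) {
			// ids diverged: materialize the per-row array once, filled with the old constant id
			per_row_ids.assign(capacity, constant_insert_id);
		}
		for (std::size_t i = start; i < end; i++) {
			per_row_ids[i] = commit_id;
		}
	}

	uint64_t InsertId(std::size_t row) const {
		return per_row_ids.empty() ? constant_insert_id : per_row_ids[row];
	}

private:
	std::size_t capacity;
	uint64_t constant_insert_id;
	std::vector<uint64_t> per_row_ids;                // lazily materialized
};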
+ auto deleted = delete_segment.GetPtr(); + fetch_deleted_id = deleted[row]; + } + + return UseVersion(transaction, fetch_insert_id) && !UseVersion(transaction, fetch_deleted_id); +} + +IndexPointer ChunkVectorInfo::GetInsertedPointer() const { + if (HasConstantInsertionId()) { + throw InternalException("ChunkVectorInfo: insert id requested but insertions were not initialized"); + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetDeletedPointer() const { + if (!AnyDeleted()) { + throw InternalException("ChunkVectorInfo: deleted id requested but deletions were not initialized"); + } + return deleted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedInsertedPointer() { + if (HasConstantInsertionId()) { + transaction_t constant_id = ConstantInsertId(); + + inserted_data = allocator.New(); + inserted_data.SetMetadata(1); + auto segment = allocator.GetHandle(inserted_data); + auto inserted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + inserted[i] = constant_id; + } + } + return inserted_data; +} + +IndexPointer ChunkVectorInfo::GetInitializedDeletedPointer() { + if (!AnyDeleted()) { + deleted_data = allocator.New(); + deleted_data.SetMetadata(1); + auto segment = allocator.GetHandle(deleted_data); + auto deleted = segment.GetPtr(); + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + deleted[i] = NOT_DELETED_ID; + } + } + return deleted_data; } idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t count) { - any_deleted = true; + auto segment = allocator.GetHandle(GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); idx_t deleted_tuples = 0; for (idx_t i = 0; i < count; i++) { @@ -220,6 +347,9 @@ idx_t ChunkVectorInfo::Delete(transaction_t transaction_id, row_t rows[], idx_t } void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &info) { + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); + if (info.is_consecutive) { for (idx_t i = 0; i < info.count; i++) { deleted[i] = commit_id; @@ -234,32 +364,45 @@ void ChunkVectorInfo::CommitDelete(transaction_t commit_id, const DeleteInfo &in void ChunkVectorInfo::Append(idx_t start, idx_t end, transaction_t commit_id) { if (start == 0) { - insert_id = commit_id; - } else if (insert_id != commit_id) { - same_inserted_id = false; - insert_id = NOT_DELETED_ID; + // first insert to this vector - just assign the commit id + constant_insert_id = commit_id; + return; } + if (HasConstantInsertionId() && ConstantInsertId() == commit_id) { + // we are inserting again, but we have the same id as before - still the same insert id + return; + } + + auto segment = allocator.GetHandle(GetInitializedInsertedPointer()); + auto inserted = segment.GetPtr(); for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } void ChunkVectorInfo::CommitAppend(transaction_t commit_id, idx_t start, idx_t end) { - if (same_inserted_id) { - insert_id = commit_id; + if (HasConstantInsertionId()) { + constant_insert_id = commit_id; + return; } + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t i = start; i < end; i++) { inserted[i] = commit_id; } } -bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr &result) const { - if (any_deleted) { +bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction) const { + if (AnyDeleted()) { // if any rows are deleted we can't clean-up return false; } // check if the insertion markers have to be used by 
all transactions going forward - if (!same_inserted_id) { + if (!HasConstantInsertionId()) { + auto segment = allocator.GetHandle(GetInsertedPointer()); + auto inserted = segment.GetPtr(); + for (idx_t idx = 1; idx < STANDARD_VECTOR_SIZE; idx++) { if (inserted[idx] > lowest_transaction) { // transaction was inserted after the lowest transaction start @@ -267,7 +410,7 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr lowest_transaction) { + } else if (ConstantInsertId() > lowest_transaction) { // transaction was inserted after the lowest transaction start // we still need to use an older version - cannot compress return false; @@ -276,13 +419,54 @@ bool ChunkVectorInfo::Cleanup(transaction_t lowest_transaction, unique_ptr(); + + for (idx_t idx = 0; idx < max_count; idx++) { + if (idx > 0) { + result += ", "; + } + result += to_string(inserted[idx]); + } + result += "]"; + } + result += "]"; + return result; +} + +transaction_t ChunkVectorInfo::ConstantInsertId() const { + if (!HasConstantInsertionId()) { + throw InternalException("ConstantInsertId() called but vector info does not have a constant insertion id"); + } + return constant_insert_id; +} + +idx_t ChunkVectorInfo::GetCommittedDeletedCount(idx_t max_count) const { + if (!AnyDeleted()) { return 0; } + auto segment = allocator.GetHandle(GetDeletedPointer()); + auto deleted = segment.GetPtr(); + idx_t delete_count = 0; for (idx_t i = 0; i < max_count; i++) { if (deleted[i] < TRANSACTION_ID_START) { @@ -319,15 +503,17 @@ void ChunkVectorInfo::Write(WriteStream &writer) const { mask.Write(writer, STANDARD_VECTOR_SIZE); } -unique_ptr ChunkVectorInfo::Read(ReadStream &reader) { +unique_ptr ChunkVectorInfo::Read(FixedSizeAllocator &allocator, ReadStream &reader) { auto start = reader.Read(); - auto result = make_uniq(start); - result->any_deleted = true; + auto result = make_uniq(allocator, start); ValidityMask mask; mask.Read(reader, STANDARD_VECTOR_SIZE); + + auto segment = allocator.GetHandle(result->GetInitializedDeletedPointer()); + auto deleted = segment.GetPtr(); for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { if (mask.RowIsValid(i)) { - result->deleted[i] = 0; + deleted[i] = 0; } } return std::move(result); diff --git a/src/duckdb/src/storage/table/column_checkpoint_state.cpp b/src/duckdb/src/storage/table/column_checkpoint_state.cpp index 213338d97..ddd40d8c2 100644 --- a/src/duckdb/src/storage/table/column_checkpoint_state.cpp +++ b/src/duckdb/src/storage/table/column_checkpoint_state.cpp @@ -9,9 +9,10 @@ namespace duckdb { -ColumnCheckpointState::ColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, +ColumnCheckpointState::ColumnCheckpointState(const RowGroup &row_group, ColumnData &original_column, PartialBlockManager &partial_block_manager) - : row_group(row_group), column_data(column_data), partial_block_manager(partial_block_manager) { + : row_group(row_group), original_column(original_column), partial_block_manager(partial_block_manager), + original_column_mutable(original_column) { } ColumnCheckpointState::~ColumnCheckpointState() { @@ -22,6 +23,27 @@ unique_ptr ColumnCheckpointState::GetStatistics() { return std::move(global_stats); } +shared_ptr ColumnCheckpointState::CreateEmptyColumnData() { + throw InternalException("CreateEmptyColumnData not implemented for this column checkpoint state"); +} + +ColumnData &ColumnCheckpointState::GetResultColumn() { + if (!result_column) { + result_column = CreateEmptyColumnData(); + } + return *result_column; +} + +shared_ptr 
ColumnCheckpointState::GetFinalResult() { + if (!result_column) { + // no result column instantiated - that means we haven't changed anything and can directly return the + // original column + return original_column_mutable.shared_from_this(); + } + result_column->SetCount(original_column.count.load()); + return result_column; +} + PartialBlockForCheckpoint::PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state, BlockManager &block_manager) : PartialBlock(state, block_manager, segment.block) { @@ -136,10 +158,9 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme if (segment->stats.statistics.IsConstant()) { // Constant block. segment->ConvertToPersistent(partial_block_manager.GetClientContext(), nullptr, INVALID_BLOCK); - } else if (segment_size != 0) { // Non-constant block with data that has to go to disk. - auto &db = column_data.GetDatabase(); + auto &db = original_column.GetDatabase(); auto &buffer_manager = BufferManager::GetBufferManager(db); partial_block_lock = partial_block_manager.GetLock(); @@ -158,7 +179,7 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme auto new_handle = buffer_manager.Pin(pstate.block_handle); // memcpy the contents of the old block to the new block memcpy(new_handle.Ptr() + offset_in_block, old_handle.Ptr(), segment_size); - pstate.AddSegmentToTail(column_data, *segment, offset_in_block); + pstate.AddSegmentToTail(*result_column, *segment, offset_in_block); } else { // Create a new block for future reuse. if (segment->SegmentSize() != block_size) { @@ -168,8 +189,8 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme segment->Resize(block_size); } D_ASSERT(offset_in_block == 0); - allocation.partial_block = partial_block_manager.CreatePartialBlock(column_data, *segment, allocation.state, - *allocation.block_manager); + allocation.partial_block = partial_block_manager.CreatePartialBlock( + *result_column, *segment, allocation.state, *allocation.block_manager); } // Writer will decide whether to reuse this block. partial_block_manager.RegisterPartialBlock(std::move(allocation)); @@ -185,7 +206,7 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme DataPointer data_pointer(segment->stats.statistics.Copy()); data_pointer.block_pointer.block_id = block_id; data_pointer.block_pointer.offset = offset_in_block; - data_pointer.row_start = row_group.start; + data_pointer.row_start = 0; if (!data_pointers.empty()) { auto &last_pointer = data_pointers.back(); data_pointer.row_start = last_pointer.row_start + last_pointer.tuple_count; @@ -198,12 +219,13 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme } // append the segment to the new segment tree - new_tree.AppendSegment(std::move(segment)); + GetResultColumn().GetSegmentTree().AppendSegment(std::move(segment)); data_pointers.push_back(std::move(data_pointer)); } PersistentColumnData ColumnCheckpointState::ToPersistentData() { - PersistentColumnData data(column_data.type.InternalType()); + auto &type = result_column ? 
result_column->type : original_column.type; + PersistentColumnData data(type); data.pointers = std::move(data_pointers); return data; } diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index c212fcb18..fa692e8d2 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -11,6 +11,7 @@ #include "duckdb/storage/table/standard_column_data.hpp" #include "duckdb/storage/table/array_column_data.hpp" #include "duckdb/storage/table/struct_column_data.hpp" +#include "duckdb/storage/table/variant_column_data.hpp" #include "duckdb/storage/table/update_segment.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/storage/table/append_state.hpp" @@ -18,13 +19,16 @@ #include "duckdb/common/serializer/read_stream.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" namespace duckdb { -ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type_p, optional_ptr parent) - : start(start_row), count(0), block_manager(block_manager), info(info), column_index(column_index), - type(std::move(type_p)), allocation_size(0), parent(parent) { +ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type_p, + ColumnDataType data_type_p, optional_ptr parent_p) + : count(0), block_manager(block_manager), info(info), column_index(column_index), type(std::move(type_p)), + allocation_size(0), + data_type(data_type_p == ColumnDataType::CHECKPOINT_TARGET ? ColumnDataType::MAIN_TABLE : data_type_p), + parent(parent_p) { if (!parent) { stats = make_uniq(type); } @@ -33,14 +37,8 @@ ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t c ColumnData::~ColumnData() { } -void ColumnData::SetStart(idx_t new_start) { - this->start = new_start; - idx_t offset = 0; - for (auto &segment : data.Segments()) { - segment.start = start + offset; - offset += segment.count; - } - data.Reinitialize(); +void ColumnData::SetDataType(ColumnDataType data_type_p) { + this->data_type = data_type_p; } DatabaseInstance &ColumnData::GetDatabase() const { @@ -55,13 +53,6 @@ StorageManager &ColumnData::GetStorageManager() const { return info.GetDB().GetStorageManager(); } -const LogicalType &ColumnData::RootType() const { - if (parent) { - return parent->RootType(); - } - return type; -} - bool ColumnData::HasUpdates() const { lock_guard update_guard(update_lock); return updates.get(); @@ -78,17 +69,15 @@ bool ColumnData::HasChanges(idx_t start_row, idx_t end_row) const { } bool ColumnData::HasChanges() const { - auto l = data.Lock(); - auto &nodes = data.ReferenceLoadedSegments(l); - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->segment_type == ColumnSegmentType::TRANSIENT) { + for (auto &segment_node : data.SegmentNodes()) { + auto &segment = segment_node.GetNode(); + if (segment.segment_type == ColumnSegmentType::TRANSIENT) { // transient segment: always need to write to disk return true; } // persistent segment; check if there were any updates or deletions in this segment - idx_t start_row_idx = segment->start - start; - idx_t end_row_idx = start_row_idx + segment->count; + idx_t start_row_idx = segment_node.GetRowStart(); + idx_t end_row_idx = start_row_idx + segment.count; if 
(HasChanges(start_row_idx, end_row_idx)) { return true; } @@ -100,11 +89,6 @@ bool ColumnData::HasAnyChanges() const { return HasChanges(); } -void ColumnData::ClearUpdates() { - lock_guard update_guard(update_lock); - updates.reset(); -} - idx_t ColumnData::GetMaxEntry() { return count; } @@ -112,18 +96,21 @@ idx_t ColumnData::GetMaxEntry() { void ColumnData::InitializeScan(ColumnScanState &state) { state.current = data.GetRootSegment(); state.segment_tree = &data; - state.row_index = state.current ? state.current->start : 0; - state.internal_index = state.row_index; + state.offset_in_column = state.current ? state.current->GetRowStart() : 0; + state.internal_index = state.offset_in_column; state.initialized = false; state.scan_state.reset(); state.last_offset = 0; } void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + if (row_idx > count) { + throw InternalException("row_idx in InitializeScanWithOffset out of range"); + } state.current = data.GetSegment(row_idx); state.segment_tree = &data; - state.row_index = row_idx; - state.internal_index = state.current->start; + state.offset_in_column = row_idx; + state.internal_index = state.current->GetRowStart(); state.initialized = false; state.scan_state.reset(); state.last_offset = 0; @@ -139,7 +126,8 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_ return ScanVectorType::SCAN_FLAT_VECTOR; } // check if the current segment has enough data remaining - idx_t remaining_in_segment = state.current->start + state.current->count - state.row_index; + auto ¤t = state.current->GetNode(); + idx_t remaining_in_segment = state.current->GetRowStart() + current.count - state.offset_in_column; if (remaining_in_segment < scan_count) { // there is not enough data remaining in the current segment so we need to scan across segments // we need flat vectors here @@ -155,38 +143,42 @@ void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanSta } if (!scan_state.initialized) { // need to prefetch for the current segment if we have not yet initialized the scan for this segment - scan_state.current->InitializePrefetch(prefetch_state, scan_state); + current_segment->GetNode().InitializePrefetch(prefetch_state, scan_state); } - idx_t row_index = scan_state.row_index; + idx_t row_index = scan_state.offset_in_column; while (remaining > 0) { - idx_t scan_count = MinValue(remaining, current_segment->start + current_segment->count - row_index); + auto ¤t = current_segment->GetNode(); + idx_t scan_count = MinValue(remaining, current_segment->GetRowStart() + current.count - row_index); remaining -= scan_count; row_index += scan_count; if (remaining > 0) { - auto next = data.GetNextSegment(current_segment); + auto next = data.GetNextSegment(*current_segment); if (!next) { break; } - next->InitializePrefetch(prefetch_state, scan_state); + next->GetNode().InitializePrefetch(prefetch_state, scan_state); current_segment = next; } } } void ColumnData::BeginScanVectorInternal(ColumnScanState &state) { + D_ASSERT(state.current); + state.previous_states.clear(); if (!state.initialized) { - D_ASSERT(state.current); - state.current->InitializeScan(state); - state.internal_index = state.current->start; + auto ¤t = state.current->GetNode(); + current.InitializeScan(state); + state.internal_index = state.current->GetRowStart(); state.initialized = true; } - D_ASSERT(data.HasSegment(state.current)); - D_ASSERT(state.internal_index <= state.row_index); - if (state.internal_index < state.row_index) { - 
state.current->Skip(state); + D_ASSERT(data.HasSegment(*state.current)); + D_ASSERT(state.internal_index <= state.offset_in_column); + if (state.internal_index < state.offset_in_column) { + auto ¤t = state.current->GetNode(); + current.Skip(state); } - D_ASSERT(state.current->type == type); + D_ASSERT(state.current->GetNode().type == type); } idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, ScanVectorType scan_type, @@ -197,70 +189,73 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai BeginScanVectorInternal(state); idx_t initial_remaining = remaining; while (remaining > 0) { - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + state.current->count); - idx_t scan_count = MinValue(remaining, state.current->start + state.current->count - state.row_index); + auto ¤t = state.current->GetNode(); + auto current_start = state.current->GetRowStart(); + D_ASSERT(state.offset_in_column >= current_start && state.offset_in_column <= current_start + current.count); + idx_t scan_count = MinValue(remaining, current_start + current.count - state.offset_in_column); idx_t result_offset = base_result_offset + initial_remaining - remaining; if (scan_count > 0) { if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < scan_count; i++) { ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, - result_offset + i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.offset_in_column + i - current_start), + result, result_offset + i); } } else { - state.current->Scan(state, scan_count, result, result_offset, scan_type); + current.Scan(state, scan_count, result, result_offset, scan_type); } - state.row_index += scan_count; + state.offset_in_column += scan_count; remaining -= scan_count; } if (remaining > 0) { - auto next = data.GetNextSegment(state.current); + auto next = data.GetNextSegment(*state.current); if (!next) { break; } state.previous_states.emplace_back(std::move(state.scan_state)); state.current = next; - state.current->InitializeScan(state); + state.current->GetNode().InitializeScan(state); state.segment_checked = false; - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + state.current->count); + D_ASSERT(state.offset_in_column >= state.current->GetRowStart() && + state.offset_in_column <= state.current->GetRowStart() + state.current->GetNode().count); } } - state.internal_index = state.row_index; + state.internal_index = state.offset_in_column; return initial_remaining - remaining; } void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t target_count, const SelectionVector &sel, idx_t sel_count) { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = state.current->GetNode(); + if (state.current->GetRowStart() + current.count - state.offset_in_column < target_count) { throw InternalException("ColumnData::SelectVector should be able to fetch everything from one segment"); } if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < sel_count; i++) { auto source_idx = sel.get_index(i); ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + source_idx), result, i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.offset_in_column + source_idx), result, i); } } else { - 
state.current->Select(state, target_count, result, sel, sel_count); + current.Select(state, target_count, result, sel, sel_count); } - state.row_index += target_count; - state.internal_index = state.row_index; + state.offset_in_column += target_count; + state.internal_index = state.offset_in_column; } void ColumnData::FilterVector(ColumnScanState &state, Vector &result, idx_t target_count, SelectionVector &sel, idx_t &sel_count, const TableFilter &filter, TableFilterState &filter_state) { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = state.current->GetNode(); + if (state.current->GetRowStart() + current.count - state.offset_in_column < target_count) { throw InternalException("ColumnData::Filter should be able to fetch everything from one segment"); } - state.current->Filter(state, target_count, result, sel, sel_count, filter, filter_state); - state.row_index += target_count; - state.internal_index = state.row_index; + current.Filter(state, target_count, result, sel, sel_count, filter, filter_state); + state.offset_in_column += target_count; + state.internal_index = state.offset_in_column; } unique_ptr ColumnData::GetUpdateStatistics() { @@ -293,13 +288,15 @@ void ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto updates->FetchRow(transaction, NumericCast(row_id), result, result_idx); } -void ColumnData::UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector) { +void ColumnData::UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, Vector &base_vector, + idx_t row_group_start) { lock_guard update_guard(update_lock); if (!updates) { updates = make_uniq(*this); } - updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); + updates->Update(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector, + row_group_start); } idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, @@ -348,8 +345,8 @@ idx_t ColumnData::GetVectorCount(idx_t vector_index) const { } void ColumnData::ScanCommittedRange(idx_t row_group_start, idx_t offset_in_row_group, idx_t s_count, Vector &result) { - ColumnScanState child_state; - InitializeScanWithOffset(child_state, row_group_start + offset_in_row_group); + ColumnScanState child_state(nullptr); + InitializeScanWithOffset(child_state, offset_in_row_group); bool has_updates = HasUpdates(); auto scan_count = ScanVector(child_state, result, s_count, ScanVectorType::SCAN_FLAT_VECTOR); if (has_updates) { @@ -401,7 +398,7 @@ void ColumnData::Append(BaseStatistics &append_stats, ColumnAppendState &state, } void ColumnData::Append(ColumnAppendState &state, Vector &vector, idx_t append_count) { - if (parent || !stats) { + if (!stats) { throw InternalException("ColumnData::Append called on a column with a parent or without stats"); } lock_guard l(stats_lock); @@ -420,7 +417,7 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt FilterPropagateResult prune_result; { lock_guard l(stats_lock); - prune_result = filter.CheckStatistics(state.current->stats.statistics); + prune_result = filter.CheckStatistics(state.current->GetNode().stats.statistics); if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return 
FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -447,7 +444,7 @@ FilterPropagateResult ColumnData::CheckZonemap(TableFilter &filter) { return filter.CheckStatistics(stats->statistics); } -unique_ptr ColumnData::GetStatistics() { +unique_ptr ColumnData::GetStatistics() const { if (!stats) { throw InternalException("ColumnData::GetStatistics called on a column without stats"); } @@ -475,31 +472,34 @@ void ColumnData::InitializeAppend(ColumnAppendState &state) { auto l = data.Lock(); if (data.IsEmpty(l)) { // no segments yet, append an empty segment - AppendTransientSegment(l, start); + AppendTransientSegment(l, 0); } auto segment = data.GetLastSegment(l); - if (segment->segment_type == ColumnSegmentType::PERSISTENT || !segment->GetCompressionFunction().init_append) { + auto &last_segment = segment->GetNode(); + if (last_segment.segment_type == ColumnSegmentType::PERSISTENT || + !last_segment.GetCompressionFunction().init_append) { // we cannot append to this segment - append a new segment - auto total_rows = segment->start + segment->count; + auto total_rows = segment->GetRowStart() + last_segment.count; AppendTransientSegment(l, total_rows); state.current = data.GetLastSegment(l); } else { state.current = segment; } - - D_ASSERT(state.current->segment_type == ColumnSegmentType::TRANSIENT); - state.current->InitializeAppend(state); - D_ASSERT(state.current->GetCompressionFunction().append); + auto &append_segment = state.current->GetNode(); + D_ASSERT(append_segment.segment_type == ColumnSegmentType::TRANSIENT); + append_segment.InitializeAppend(state); + D_ASSERT(append_segment.GetCompressionFunction().append); } void ColumnData::AppendData(BaseStatistics &append_stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t append_count) { idx_t offset = 0; - this->count += append_count; while (true) { // append the data from the vector - idx_t copied_elements = state.current->Append(state, vdata, offset, append_count); - append_stats.Merge(state.current->stats.statistics); + auto &append_segment = state.current->GetNode(); + idx_t copied_elements = append_segment.Append(state, vdata, offset, append_count); + this->count += copied_elements; + append_stats.Merge(append_segment.stats.statistics); if (copied_elements == append_count) { // finished copying everything break; @@ -508,99 +508,114 @@ void ColumnData::AppendData(BaseStatistics &append_stats, ColumnAppendState &sta // we couldn't fit everything we wanted in the current column segment, create a new one { auto l = data.Lock(); - AppendTransientSegment(l, state.current->start + state.current->count); + AppendTransientSegment(l, state.current->GetRowStart() + append_segment.count); state.current = data.GetLastSegment(l); - state.current->InitializeAppend(state); + state.current->GetNode().InitializeAppend(state); } offset += copied_elements; append_count -= copied_elements; } } -void ColumnData::RevertAppend(row_t start_row_p) { - idx_t start_row = NumericCast(start_row_p); +void ColumnData::RevertAppend(row_t new_count_p) { + idx_t new_count = NumericCast(new_count_p); auto l = data.Lock(); // check if this row is in the segment tree at all - auto last_segment = data.GetLastSegment(l); - if (!last_segment) { + auto last_segment_node = data.GetLastSegment(l); + if (!last_segment_node) { return; } - if (start_row >= last_segment->start + last_segment->count) { + auto &last_segment = last_segment_node->GetNode(); + if (new_count >= last_segment_node->GetRowStart() + last_segment.count) { // the start row is equal to the final portion 
of the column data: nothing was ever appended here - D_ASSERT(start_row == last_segment->start + last_segment->count); + D_ASSERT(new_count == last_segment_node->GetRowStart() + last_segment.count); return; } // find the segment index that the current row belongs to - idx_t segment_index = data.GetSegmentIndex(l, start_row); + idx_t segment_index = data.GetSegmentIndex(l, new_count); auto segment = data.GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); - if (segment->start == start_row) { + if (segment->GetRowStart() == new_count) { // we are truncating exactly this segment - erase it entirely data.EraseSegments(l, segment_index); + if (segment_index > 0) { + // if we have a previous segment, we need to update the next pointer + auto previous_segment = data.GetSegmentByIndex(l, UnsafeNumericCast(segment_index - 1)); + previous_segment->SetNext(nullptr); + } } else { // we need to truncate within the segment // remove any segments AFTER this segment: they should be deleted entirely data.EraseSegments(l, segment_index + 1); - auto &transient = *segment; + auto &transient = segment->GetNode(); D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); - segment->next = nullptr; - transient.RevertAppend(start_row); + segment->SetNext(nullptr); + transient.RevertAppend(new_count - segment->GetRowStart()); } - this->count = start_row - this->start; + this->count = new_count; } idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + if (UnsafeNumericCast(row_id) > count) { + throw InternalException("ColumnData::Fetch - row_id out of range"); + } D_ASSERT(row_id >= 0); - D_ASSERT(NumericCast(row_id) >= start); // perform the fetch within the segment - state.row_index = - start + ((UnsafeNumericCast(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); - state.current = data.GetSegment(state.row_index); - state.internal_index = state.current->start; + state.offset_in_column = UnsafeNumericCast(row_id) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE; + state.current = data.GetSegment(state.offset_in_column); + state.internal_index = state.current->GetRowStart(); return ScanVector(state, result, STANDARD_VECTOR_SIZE, ScanVectorType::SCAN_FLAT_VECTOR); } void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { + if (UnsafeNumericCast(row_id) > count) { + throw InternalException("ColumnData::FetchRow - row_id out of range"); + } auto segment = data.GetSegment(UnsafeNumericCast(row_id)); // now perform the fetch within the segment - segment->FetchRow(state, row_id, result, result_idx); + auto index_in_segment = row_id - UnsafeNumericCast(segment->GetRowStart()); + segment->GetNode().FetchRow(state, index_in_segment, result, result_idx); // merge any updates made to this row FetchUpdateRow(transaction, row_id, result, result_idx); } -idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector) { - auto fetch_count = ColumnData::Fetch(state, row_ids[0], base_vector); +idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector, idx_t row_group_start) { + if (row_ids[0] < UnsafeNumericCast(row_group_start)) { + throw InternalException("ColumnData::FetchUpdateData out of range"); + } + auto fetch_count = ColumnData::Fetch(state, row_ids[0] - UnsafeNumericCast(row_group_start), base_vector); base_vector.Flatten(fetch_count); return fetch_count; } -void ColumnData::Update(TransactionData transaction, idx_t column_index, 
Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, idx_t row_group_start) { Vector base_vector(type); - ColumnScanState state; - FetchUpdateData(state, row_ids, base_vector); + ColumnScanState state(nullptr); + FetchUpdateData(state, row_ids, base_vector, row_group_start); - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector, + row_group_start); } -void ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) { +void ColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth, + idx_t row_group_start) { // this method should only be called at the end of the path in the base column case D_ASSERT(depth >= column_path.size()); - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count, row_group_start); } void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { - const auto block_size = block_manager.GetBlockSize(); const auto type_size = GetTypeIdSize(type.InternalType()); auto vector_segment_size = block_size; - if (start_row == NumericCast(MAX_ROW_ID)) { + if (data_type == ColumnDataType::INITIAL_TRANSACTION_LOCAL && start_row == 0) { #if STANDARD_VECTOR_SIZE < 1024 vector_segment_size = 1024 * type_size; #else @@ -616,8 +631,7 @@ void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { auto &config = DBConfig::GetConfig(db); auto function = config.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, type.InternalType()); - auto new_segment = - ColumnSegment::CreateTransientSegment(db, *function, type, start_row, segment_size, block_manager); + auto new_segment = ColumnSegment::CreateTransientSegment(db, *function, type, segment_size, block_manager); AppendSegment(l, std::move(new_segment)); } @@ -641,24 +655,25 @@ void ColumnData::AppendSegment(SegmentLock &l, unique_ptr segment data.AppendSegment(l, std::move(segment)); } -void ColumnData::CommitDropColumn() { +void ColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { for (auto &segment_p : data.Segments()) { auto &segment = segment_p; - segment.CommitDropSegment(); + segment.VisitBlockIds(visitor); } } -unique_ptr ColumnData::CreateCheckpointState(RowGroup &row_group, +unique_ptr ColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { return make_uniq(row_group, *this, partial_block_manager); } -void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector) { +void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, + Vector &scan_vector) const { if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < count; i++) { ColumnFetchState fetch_state; - segment.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), scan_vector, i); + fetch_state.row_group = state.parent->row_group; + segment.FetchRow(fetch_state, 
UnsafeNumericCast(state.offset_in_column + i), scan_vector, i); } } else { segment.Scan(state, count, scan_vector, 0, ScanVectorType::SCAN_FLAT_VECTOR); @@ -666,18 +681,19 @@ void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, if (updates) { D_ASSERT(scan_vector.GetVectorType() == VectorType::FLAT_VECTOR); - updates->FetchCommittedRange(state.row_index - row_group_start, count, scan_vector); + updates->FetchCommittedRange(state.offset_in_column, count, scan_vector); } } -unique_ptr ColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { +unique_ptr ColumnData::Checkpoint(const RowGroup &row_group, + ColumnCheckpointInfo &checkpoint_info) { // scan the segments of the column data // set up the checkpoint state - auto checkpoint_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = CreateCheckpointState(row_group, partial_block_manager); checkpoint_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); - auto &nodes = data.ReferenceSegments(); - if (nodes.empty()) { + if (!data.GetRootSegment()) { // empty table: flush the empty list return checkpoint_state; } @@ -699,6 +715,7 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist this->count = 0; for (auto &data_pointer : column_data.pointers) { // Update the count and statistics + data_pointer.row_start = count; this->count += data_pointer.tuple_count; // Merge the statistics. If this is a child column, the target_stats reference will point into the parents stats @@ -709,8 +726,8 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist // create a persistent segment auto segment = ColumnSegment::CreatePersistentSegment( GetDatabase(), block_manager, data_pointer.block_pointer.block_id, data_pointer.block_pointer.offset, type, - data_pointer.row_start, data_pointer.tuple_count, data_pointer.compression_type, - std::move(data_pointer.statistics), std::move(data_pointer.segment_state)); + data_pointer.tuple_count, data_pointer.compression_type, std::move(data_pointer.statistics), + std::move(data_pointer.segment_state)); auto l = data.Lock(); AppendSegment(l, std::move(segment)); @@ -728,17 +745,20 @@ bool ColumnData::IsPersistent() { vector ColumnData::GetDataPointers() { vector pointers; + idx_t row_start = 0; for (auto &segment : data.Segments()) { - pointers.push_back(segment.GetDataPointer()); + pointers.push_back(segment.GetDataPointer(row_start)); + row_start += segment.count; } return pointers; } -PersistentColumnData::PersistentColumnData(PhysicalType physical_type_p) : physical_type(physical_type_p) { +PersistentColumnData::PersistentColumnData(const LogicalType &logical_type) + : physical_type(logical_type.InternalType()), logical_type_id(logical_type.id()) { } -PersistentColumnData::PersistentColumnData(PhysicalType physical_type, vector pointers_p) - : physical_type(physical_type), pointers(std::move(pointers_p)) { +PersistentColumnData::PersistentColumnData(const LogicalType &logical_type, vector pointers_p) + : physical_type(logical_type.InternalType()), logical_type_id(logical_type.id()), pointers(std::move(pointers_p)) { D_ASSERT(!pointers.empty()); } @@ -756,6 +776,22 @@ void PersistentColumnData::Serialize(Serializer &serializer) const { return; } serializer.WriteProperty(101, "validity", child_columns[0]); + + if (logical_type_id == LogicalTypeId::VARIANT) { + D_ASSERT(physical_type 
== PhysicalType::STRUCT); + D_ASSERT(child_columns.size() == 2 || child_columns.size() == 3); + + auto unshredded_type = VariantShredding::GetUnshreddedType(); + serializer.WriteProperty(102, "unshredded", child_columns[1]); + + if (child_columns.size() == 3) { + D_ASSERT(variant_shredded_type.id() == LogicalTypeId::STRUCT); + serializer.WriteProperty(115, "shredded_type", variant_shredded_type); + serializer.WriteProperty(120, "shredded", child_columns[2]); + } + return; + } + if (physical_type == PhysicalType::ARRAY || physical_type == PhysicalType::LIST) { D_ASSERT(child_columns.size() == 2); serializer.WriteProperty(102, "child_column", child_columns[1]); @@ -775,13 +811,32 @@ void PersistentColumnData::DeserializeField(Deserializer &deserializer, field_id PersistentColumnData PersistentColumnData::Deserialize(Deserializer &deserializer) { auto &type = deserializer.Get(); auto physical_type = type.InternalType(); - PersistentColumnData result(physical_type); + PersistentColumnData result(type); deserializer.ReadPropertyWithDefault(100, "data_pointers", static_cast &>(result.pointers)); if (result.physical_type == PhysicalType::BIT) { // validity: return return result; } result.DeserializeField(deserializer, 101, "validity", LogicalTypeId::VALIDITY); + + if (type.id() == LogicalTypeId::VARIANT) { + auto unshredded_type = VariantShredding::GetUnshreddedType(); + + deserializer.Set(unshredded_type); + result.child_columns.push_back(deserializer.ReadProperty(102, "unshredded")); + deserializer.Unset(); + + auto shredded_type = + deserializer.ReadPropertyWithExplicitDefault(115, "shredded_type", LogicalType()); + if (shredded_type.id() == LogicalTypeId::STRUCT) { + deserializer.Set(shredded_type); + result.child_columns.push_back(deserializer.ReadProperty(120, "shredded")); + deserializer.Unset(); + result.SetVariantShreddedType(shredded_type); + } + return result; + } + switch (physical_type) { case PhysicalType::ARRAY: result.DeserializeField(deserializer, 102, "child_column", ArrayType::GetChildType(type)); @@ -816,6 +871,12 @@ bool PersistentColumnData::HasUpdates() const { return false; } +void PersistentColumnData::SetVariantShreddedType(const LogicalType &shredded_type) { + D_ASSERT(physical_type == PhysicalType::STRUCT); + D_ASSERT(logical_type_id == LogicalTypeId::VARIANT); + variant_shredded_type = shredded_type; +} + PersistentRowGroupData::PersistentRowGroupData(vector types_p) : types(std::move(types_p)) { } @@ -868,25 +929,14 @@ bool PersistentCollectionData::HasUpdates() const { } PersistentColumnData ColumnData::Serialize() { - PersistentColumnData result(type.InternalType(), GetDataPointers()); + auto result = count ? 
PersistentColumnData(type, GetDataPointers()) : PersistentColumnData(type); result.has_updates = HasUpdates(); return result; } -void RealignColumnData(PersistentColumnData &column_data, idx_t new_start) { - idx_t current_start = new_start; - for (auto &pointer : column_data.pointers) { - pointer.row_start = current_start; - current_start += pointer.tuple_count; - } - for (auto &child : column_data.child_columns) { - RealignColumnData(child, new_start); - } -} - shared_ptr ColumnData::Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, ReadStream &source, const LogicalType &type) { - auto entry = ColumnData::CreateColumn(block_manager, info, column_index, start_row, type, nullptr); + ReadStream &source, const LogicalType &type) { + auto entry = ColumnData::CreateColumn(block_manager, info, column_index, type); // deserialize the persistent column data BinaryDeserializer deserializer(source); @@ -901,15 +951,23 @@ shared_ptr ColumnData::Deserialize(BlockManager &block_manager, Data deserializer.Unset(); deserializer.End(); - // re-align data segments, in case our start_row has changed - RealignColumnData(persistent_column_data, start_row); - // initialize the column entry->InitializeColumn(persistent_column_data, entry->stats->statistics); return entry; } -void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +struct ListBlockIds : public BlockIdVisitor { + explicit ListBlockIds(vector &block_ids) : block_ids(block_ids) { + } + + void Visit(block_id_t block_id) override { + block_ids.push_back(block_id); + } + + vector &block_ids; +}; + +void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { D_ASSERT(!col_path.empty()); @@ -925,40 +983,45 @@ void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_p // iterate over the segments idx_t segment_idx = 0; - auto segment = data.GetRootSegment(); - while (segment) { + for (auto &segment_node : data.SegmentNodes()) { + auto &segment = segment_node.GetNode(); ColumnSegmentInfo column_info; column_info.row_group_index = row_group_index; column_info.column_id = col_path[0]; column_info.column_path = col_path_str; column_info.segment_idx = segment_idx; column_info.segment_type = type.ToString(); - column_info.segment_start = segment->start; - column_info.segment_count = segment->count; - column_info.compression_type = CompressionTypeToString(segment->GetCompressionFunction().type); + column_info.segment_start = segment_node.GetRowStart(); + column_info.segment_count = segment.count; + column_info.compression_type = CompressionTypeToString(segment.GetCompressionFunction().type); { lock_guard l(stats_lock); - column_info.segment_stats = segment->stats.statistics.ToString(); + column_info.segment_stats = segment.stats.statistics.ToString(); } column_info.has_updates = ColumnData::HasUpdates(); // persistent // block_id // block_offset - if (segment->segment_type == ColumnSegmentType::PERSISTENT) { + if (segment.segment_type == ColumnSegmentType::PERSISTENT) { column_info.persistent = true; - column_info.block_id = segment->GetBlockId(); - column_info.block_offset = segment->GetBlockOffset(); + column_info.block_id = segment.GetBlockId(); + column_info.block_offset = segment.GetBlockOffset(); } else { column_info.persistent = false; + column_info.block_id = INVALID_BLOCK; + column_info.block_offset = 0; } - auto &compression_function = segment->GetCompressionFunction(); - auto segment_state = 
segment->GetSegmentState(); + auto &compression_function = segment.GetCompressionFunction(); + auto segment_state = segment.GetSegmentState(); if (segment_state) { column_info.segment_info = segment_state->GetSegmentInfo(); - column_info.additional_blocks = segment_state->GetAdditionalBlocks(); + if (compression_function.visit_block_ids) { + ListBlockIds list_block_ids(column_info.additional_blocks); + compression_function.visit_block_ids(segment, list_block_ids); + } } if (compression_function.get_segment_info) { - auto segment_info = compression_function.get_segment_info(*segment); + auto segment_info = compression_function.get_segment_info(context, segment); vector sinfo; for (auto &item : segment_info) { auto &mode = item.first; @@ -970,13 +1033,11 @@ void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_p result.emplace_back(column_info); segment_idx++; - segment = data.GetNextSegment(segment); } } void ColumnData::Verify(RowGroup &parent) { #ifdef DEBUG - D_ASSERT(this->start == parent.start); data.Verify(); if (type.InternalType() == PhysicalType::STRUCT || type.InternalType() == PhysicalType::ARRAY) { // structs and fixed size lists don't have segments @@ -984,46 +1045,35 @@ void ColumnData::Verify(RowGroup &parent) { return; } idx_t current_index = 0; - idx_t current_start = this->start; + idx_t current_start = 0; idx_t total_count = 0; - for (auto &segment : data.Segments()) { - D_ASSERT(segment.index == current_index); - D_ASSERT(segment.start == current_start); - current_start += segment.count; - total_count += segment.count; + for (auto &segment : data.SegmentNodes()) { + D_ASSERT(segment.GetIndex() == current_index); + D_ASSERT(segment.GetRowStart() == current_start); + current_start += segment.GetNode().count; + total_count += segment.GetNode().count; current_index++; } D_ASSERT(this->count == total_count); #endif } -template -static RET CreateColumnInternal(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - const LogicalType &type, optional_ptr parent) { +shared_ptr ColumnData::CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + const LogicalType &type, ColumnDataType data_type, + optional_ptr parent) { + if (type.id() == LogicalTypeId::VARIANT) { + return make_shared_ptr(block_manager, info, column_index, type, data_type, parent); + } if (type.InternalType() == PhysicalType::STRUCT) { - return OP::template Create(block_manager, info, column_index, start_row, type, parent); + return make_shared_ptr(block_manager, info, column_index, type, data_type, parent); } else if (type.InternalType() == PhysicalType::LIST) { - return OP::template Create(block_manager, info, column_index, start_row, type, parent); + return make_shared_ptr(block_manager, info, column_index, type, data_type, parent); } else if (type.InternalType() == PhysicalType::ARRAY) { - return OP::template Create(block_manager, info, column_index, start_row, type, parent); + return make_shared_ptr(block_manager, info, column_index, type, data_type, parent); } else if (type.id() == LogicalTypeId::VALIDITY) { - return OP::template Create(block_manager, info, column_index, start_row, *parent); + return make_shared_ptr(block_manager, info, column_index, data_type, parent); } - return OP::template Create(block_manager, info, column_index, start_row, type, parent); -} - -shared_ptr ColumnData::CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, const LogicalType &type, - optional_ptr 
parent) { - return CreateColumnInternal, SharedConstructor>(block_manager, info, column_index, start_row, - type, parent); -} - -unique_ptr ColumnData::CreateColumnUnique(BlockManager &block_manager, DataTableInfo &info, - idx_t column_index, idx_t start_row, const LogicalType &type, - optional_ptr parent) { - return CreateColumnInternal, UniqueConstructor>(block_manager, info, column_index, start_row, - type, parent); + return make_shared_ptr(block_manager, info, column_index, type, data_type, parent); } } // namespace duckdb diff --git a/src/duckdb/src/storage/table/column_data_checkpointer.cpp b/src/duckdb/src/storage/table/column_data_checkpointer.cpp index 68c35f842..1d7fb4cfe 100644 --- a/src/duckdb/src/storage/table/column_data_checkpointer.cpp +++ b/src/duckdb/src/storage/table/column_data_checkpointer.cpp @@ -31,7 +31,7 @@ ColumnData &ColumnDataCheckpointData::GetColumnData() { return *col_data; } -RowGroup &ColumnDataCheckpointData::GetRowGroup() { +const RowGroup &ColumnDataCheckpointData::GetRowGroup() { return *row_group; } @@ -49,7 +49,7 @@ static Vector CreateIntermediateVector(vector> D_ASSERT(!states.empty()); auto &first_state = states[0]; - auto &col_data = first_state.get().column_data; + auto &col_data = first_state.get().original_column; auto &type = col_data.type; if (type.id() == LogicalTypeId::VALIDITY) { return Vector(LogicalType::BOOLEAN, true, /* initialize_to_zero = */ true); @@ -61,16 +61,15 @@ static Vector CreateIntermediateVector(vector> } ColumnDataCheckpointer::ColumnDataCheckpointer(vector> &checkpoint_states, - StorageManager &storage_manager, RowGroup &row_group, + StorageManager &storage_manager, const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) : checkpoint_states(checkpoint_states), storage_manager(storage_manager), row_group(row_group), intermediate(CreateIntermediateVector(checkpoint_states)), checkpoint_info(checkpoint_info) { - auto &db = storage_manager.GetDatabase(); auto &config = DBConfig::GetConfig(db); compression_functions.resize(checkpoint_states.size()); for (idx_t i = 0; i < checkpoint_states.size(); i++) { - auto &col_data = checkpoint_states[i].get().column_data; + auto &col_data = checkpoint_states[i].get().original_column; auto to_add = config.GetCompressionFunctions(col_data.type.InternalType()); auto &functions = compression_functions[i]; for (auto &func : to_add) { @@ -82,23 +81,22 @@ ColumnDataCheckpointer::ColumnDataCheckpointer(vector &callback) { Vector scan_vector(intermediate.GetType(), nullptr); auto &first_state = checkpoint_states[0]; - auto &col_data = first_state.get().column_data; - auto &nodes = col_data.data.ReferenceSegments(); + auto &col_data = first_state.get().original_column; // TODO: scan all the nodes from all segments, no need for CheckpointScan to virtualize this I think.. 
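// --- Editorial sketch (not part of the patch) --------------------------------
// The rewritten loop below iterates the column's segment tree via SegmentNodes()
// and feeds the checkpoint callback one vector-sized chunk at a time; the scan
// position it sets (offset_in_column) is the segment node's row start plus the
// chunk offset, i.e. an offset relative to the column rather than an absolute
// row id. A minimal standalone model of that chunking, using hypothetical
// SketchSegment/emit names rather than DuckDB types:
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

struct SketchSegment {
	std::size_t row_start; // offset of this segment within the column
	std::size_t count;     // number of rows stored in this segment
};

inline void SketchScanSegments(const std::vector<SketchSegment> &segments, std::size_t vector_size,
                               const std::function<void(std::size_t, std::size_t)> &emit) {
	for (const auto &segment : segments) {
		for (std::size_t base = 0; base < segment.count; base += vector_size) {
			// column-relative scan position, analogous to scan_state.offset_in_column
			const std::size_t offset_in_column = segment.row_start + base;
			const std::size_t chunk = std::min(segment.count - base, vector_size);
			emit(offset_in_column, chunk);
		}
	}
}
// --- End editorial sketch -----------------------------------------------------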
- for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto &segment = *nodes[segment_idx].node; - ColumnScanState scan_state; - scan_state.current = &segment; + for (auto &segment_node : col_data.data.SegmentNodes()) { + auto &segment = segment_node.GetNode(); + ColumnScanState scan_state(nullptr); + scan_state.current = segment_node; segment.InitializeScan(scan_state); for (idx_t base_row_index = 0; base_row_index < segment.count; base_row_index += STANDARD_VECTOR_SIZE) { scan_vector.Reference(intermediate); idx_t count = MinValue(segment.count - base_row_index, STANDARD_VECTOR_SIZE); - scan_state.row_index = segment.start + base_row_index; + scan_state.offset_in_column = segment_node.GetRowStart() + base_row_index; - col_data.CheckpointScan(segment, scan_state, row_group.start, count, scan_vector); + col_data.CheckpointScan(segment, scan_state, count, scan_vector); callback(scan_vector, count); } } @@ -109,9 +107,10 @@ CompressionType ForceCompression(StorageManager &storage_manager, CompressionType compression_type) { // One of the force_compression flags has been set // check if this compression method is available - // if (CompressionTypeIsDeprecated(compression_type, storage_manager)) { + // auto compression_availability_result = CompressionTypeIsAvailable(compression_type, storage_manager); + // if (!compression_availability_result.IsAvailable()) { // throw InvalidInputException("The forced compression method (%s) is not available in the current storage - // version", CompressionTypeToString(compression_type)); + // version", CompressionTypeToString(compression_type)); //} bool found = false; @@ -143,14 +142,10 @@ CompressionType ForceCompression(StorageManager &storage_manager, void ColumnDataCheckpointer::InitAnalyze() { analyze_states.resize(checkpoint_states.size()); for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } - auto &functions = compression_functions[i]; auto &states = analyze_states[i]; auto &checkpoint_state = checkpoint_states[i]; - auto &coldata = checkpoint_state.get().column_data; + auto &coldata = checkpoint_state.get().GetResultColumn(); states.resize(functions.size()); for (idx_t j = 0; j < functions.size(); j++) { auto &func = functions[j]; @@ -185,10 +180,6 @@ vector ColumnDataCheckpointer::DetectBestCompressionMet // scan over all the segments and run the analyze step ScanSegments([&](Vector &scan_vector, idx_t count) { for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } - auto &functions = compression_functions[i]; auto &states = analyze_states[i]; for (idx_t j = 0; j < functions.size(); j++) { @@ -210,9 +201,6 @@ vector ColumnDataCheckpointer::DetectBestCompressionMet result.resize(checkpoint_states.size()); for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } auto &functions = compression_functions[i]; auto &states = analyze_states[i]; auto &forced_method = forced_methods[i]; @@ -253,7 +241,7 @@ vector ColumnDataCheckpointer::DetectBestCompressionMet } auto &checkpoint_state = checkpoint_states[i]; - auto &col_data = checkpoint_state.get().column_data; + auto &col_data = checkpoint_state.get().GetResultColumn(); if (!chosen_state) { throw FatalException("No suitable compression/storage method found to store column of type %s", col_data.type.ToString()); @@ -261,32 +249,38 @@ vector ColumnDataCheckpointer::DetectBestCompressionMet D_ASSERT(compression_idx != DConstants::INVALID_INDEX); auto &best_function 
= *functions[compression_idx]; - DUCKDB_LOG_INFO(db, "ColumnDataCheckpointer FinalAnalyze(%s) result for %s.%s.%d(%s): %d", - EnumUtil::ToString(best_function.type), col_data.info.GetSchemaName(), - col_data.info.GetTableName(), col_data.column_index, col_data.type.ToString(), best_score); + DUCKDB_LOG_TRACE(db, "ColumnDataCheckpointer FinalAnalyze(%s) result for %s.%s.%d(%s): %d", + EnumUtil::ToString(best_function.type), col_data.info.GetSchemaName(), + col_data.info.GetTableName(), col_data.column_index, col_data.type.ToString(), best_score); result[i] = CheckpointAnalyzeResult(std::move(chosen_state), best_function); } return result; } +struct CheckpointBlockIdDropper : public BlockIdVisitor { + explicit CheckpointBlockIdDropper(BlockManager &manager) : manager(manager) { + } + + void Visit(block_id_t block_id) override { + manager.MarkBlockAsModified(block_id); + } + + BlockManager &manager; +}; + void ColumnDataCheckpointer::DropSegments() { // first we check the current segments // if there are any persistent segments, we will mark their old block ids as modified // since the segments will be rewritten their old on disk data is no longer required for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } - auto &state = checkpoint_states[i]; - auto &col_data = state.get().column_data; - auto &nodes = col_data.data.ReferenceSegments(); + auto &col_data = state.get().original_column; // Drop the segments, as we'll be replacing them with new ones, because there are changes - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - segment->CommitDropSegment(); + CheckpointBlockIdDropper dropper(storage_manager.GetBlockManager()); + for (auto &segment : col_data.data.Segments()) { + segment.VisitBlockIds(dropper); } } } @@ -295,19 +289,13 @@ bool ColumnDataCheckpointer::ValidityCoveredByBasedata(vectorvalidity == CompressionValidity::NO_VALIDITY_REQUIRED; } -void ColumnDataCheckpointer::WriteToDisk() { - DropSegments(); - - // Analyze the candidate functions to select one of them to use for compression +void ColumnDataCheckpointer::WriteToDisk() { // Analyze the candidate functions to select one of them to use for + // compression auto analyze_result = DetectBestCompressionMethod(); if (ValidityCoveredByBasedata(analyze_result)) { D_ASSERT(analyze_result.size() == 2); @@ -324,26 +312,20 @@ void ColumnDataCheckpointer::WriteToDisk() { vector checkpoint_data(checkpoint_states.size()); vector> compression_states(checkpoint_states.size()); for (idx_t i = 0; i < analyze_result.size(); i++) { - if (!has_changes[i]) { - continue; - } auto &analyze_state = analyze_result[i].analyze_state; auto &function = analyze_result[i].function; auto &checkpoint_state = checkpoint_states[i]; - auto &col_data = checkpoint_state.get().column_data; + auto &col_data = checkpoint_state.get().GetResultColumn(); - checkpoint_data[i] = ColumnDataCheckpointData(checkpoint_state, col_data, col_data.GetDatabase(), row_group, - checkpoint_info, storage_manager); + checkpoint_data[i] = + ColumnDataCheckpointData(checkpoint_state, col_data, col_data.GetDatabase(), row_group, storage_manager); compression_states[i] = function->init_compression(checkpoint_data[i], std::move(analyze_state)); } // Scan over the existing segment + changes and compress the data ScanSegments([&](Vector &scan_vector, idx_t count) { for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } auto &function = 
analyze_result[i].function; auto &compression_state = compression_states[i]; function->compress(*compression_state, scan_vector, count); @@ -352,77 +334,87 @@ void ColumnDataCheckpointer::WriteToDisk() { // Finalize the compression for (idx_t i = 0; i < checkpoint_states.size(); i++) { - if (!has_changes[i]) { - continue; - } auto &function = analyze_result[i].function; auto &compression_state = compression_states[i]; function->compress_finalize(*compression_state); } + + // after we finish checkpointing we can drop this segment + DropSegments(); } bool ColumnDataCheckpointer::HasChanges(ColumnData &col_data) { - return col_data.HasChanges(); + return col_data.HasAnyChanges(); } void ColumnDataCheckpointer::WritePersistentSegments(ColumnCheckpointState &state) { // all segments are persistent and there are no updates // we only need to write the metadata - auto &col_data = state.column_data; - auto nodes = col_data.data.MoveSegments(); - - idx_t current_row = row_group.start; - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->start != current_row) { - string extra_info; - for (auto &s : nodes) { - extra_info += "\n"; - extra_info += StringUtil::Format("Start %d, count %d", s.node->start, s.node->count.load()); - } - const_reference root = col_data; - while (root.get().HasParent()) { - root = root.get().Parent(); - } - throw InternalException( - "Failure in RowGroup::Checkpoint - column data pointer is unaligned with row group " - "start\nRow group start: %d\nRow group count %d\nCurrent row: %d\nSegment start: %d\nColumn index: " - "%d\nColumn type: %s\nRoot type: %s\nTable: %s.%s\nAll segments:%s", - row_group.start, row_group.count.load(), current_row, segment->start, root.get().column_index, - col_data.type, root.get().type, root.get().info.GetSchemaName(), root.get().info.GetTableName(), - extra_info); + auto &col_data = state.original_column; + + optional_idx error_segment_start; + idx_t current_row = 0; + for (auto &segment_node : col_data.data.SegmentNodes()) { + auto &segment = segment_node.GetNode(); + auto segment_start = segment_node.GetRowStart(); + if (segment_start != current_row) { + error_segment_start = segment_start; + break; } - current_row += segment->count; - auto pointer = segment->GetDataPointer(); + auto pointer = segment.GetDataPointer(current_row); + current_row += segment.count; // merge the persistent stats into the global column stats - state.global_stats->Merge(segment->stats.statistics); - - // directly append the current segment to the new tree - state.new_tree.AppendSegment(std::move(nodes[segment_idx].node)); - + state.global_stats->Merge(segment.stats.statistics); state.data_pointers.push_back(std::move(pointer)); } + if (error_segment_start.IsValid()) { + string extra_info; + for (auto &s : col_data.data.SegmentNodes()) { + extra_info += "\n"; + extra_info += StringUtil::Format("Start %d, count %d", s.GetRowStart(), s.GetNode().count.load()); + } + throw InternalException( + "Failure in RowGroup::Checkpoint - column data pointer is unaligned with row group " + "start\nRow group start: %d\nRow group count %d\nCurrent row: %d\nSegment start: %d\nColumn index: " + "%d\nColumn type: %s\nRoot type: %s\nTable: %s.%s\nAll segments:%s", + row_group.count.load(), current_row, error_segment_start.GetIndex(), col_data.column_index, col_data.type, + col_data.type, col_data.info.GetSchemaName(), col_data.info.GetTableName(), extra_info); + } } +struct CheckpointBlockIdMarker : 
public BlockIdVisitor { + explicit CheckpointBlockIdMarker(BlockManager &manager) : manager(manager) { + } + + void Visit(block_id_t block_id) override { + manager.MarkBlockACheckpointed(block_id); + } + + BlockManager &manager; +}; + void ColumnDataCheckpointer::Checkpoint() { for (idx_t i = 0; i < checkpoint_states.size(); i++) { auto &state = checkpoint_states[i]; - auto &col_data = state.get().column_data; - has_changes.push_back(HasChanges(col_data)); - } - - bool any_has_changes = false; - for (idx_t i = 0; i < has_changes.size(); i++) { - if (has_changes[i]) { - any_has_changes = true; + auto &col_data = state.get().original_column; + if (col_data.HasChanges()) { + has_changes = true; break; } } - if (!any_has_changes) { + + if (!has_changes) { // Nothing has undergone any changes, no need to checkpoint // just move on to finalizing + // mark block ids as checkpointed + CheckpointBlockIdMarker marker(storage_manager.GetBlockManager()); + for (idx_t i = 0; i < checkpoint_states.size(); i++) { + auto &state = checkpoint_states[i]; + auto &col_data = state.get().original_column; + col_data.VisitBlockIds(marker); + } return; } @@ -430,26 +422,15 @@ void ColumnDataCheckpointer::Checkpoint() { } void ColumnDataCheckpointer::FinalizeCheckpoint() { + if (has_changes) { + // something has undergone changes, we rewrote everything + // write the new data - not the old data + return; + } + // no changes - copy over the original columns for (idx_t i = 0; i < checkpoint_states.size(); i++) { auto &state = checkpoint_states[i].get(); - auto &col_data = state.column_data; - if (has_changes[i]) { - // Move the existing segments out of the column data - // they will be destructed at the end of the scope - auto to_delete = col_data.data.MoveSegments(); - } else { - WritePersistentSegments(state); - } - - // reset the compression function - col_data.compression.reset(); - // replace the old tree with the new one - auto new_segments = state.new_tree.MoveSegments(); - auto l = col_data.data.Lock(); - for (auto &new_segment : new_segments) { - col_data.AppendSegment(l, std::move(new_segment.node)); - } - col_data.ClearUpdates(); + WritePersistentSegments(state); } } diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp index 347463fbe..e1739bc8a 100644 --- a/src/duckdb/src/storage/table/column_segment.cpp +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -14,7 +14,9 @@ #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table/update_segment.hpp" #include "duckdb/planner/table_filter_state.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include @@ -26,11 +28,10 @@ namespace duckdb { unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, block_id_t block_id, idx_t offset, - const LogicalType &type, idx_t start, idx_t count, + const LogicalType &type, idx_t count, CompressionType compression_type, BaseStatistics statistics, unique_ptr segment_state) { - auto &config = DBConfig::GetConfig(db); optional_ptr function; shared_ptr block; @@ -41,20 +42,19 @@ unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstanc } auto segment_size = block_manager.GetBlockSize(); - return make_uniq(db, std::move(block), type, ColumnSegmentType::PERSISTENT, start, count, *function, + return make_uniq(db, std::move(block), type, ColumnSegmentType::PERSISTENT, count, 
*function, std::move(statistics), block_id, offset, segment_size, std::move(segment_state)); } unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance &db, CompressionFunction &function, - const LogicalType &type, const idx_t start, - const idx_t segment_size, BlockManager &block_manager) { - + const LogicalType &type, const idx_t segment_size, + BlockManager &block_manager) { // Allocate a buffer for the uncompressed segment. auto &buffer_manager = BufferManager::GetBufferManager(db); D_ASSERT(&buffer_manager == &block_manager.buffer_manager); auto block = buffer_manager.RegisterTransientMemory(segment_size, block_manager); - return make_uniq(db, std::move(block), type, ColumnSegmentType::TRANSIENT, start, 0U, function, + return make_uniq(db, std::move(block), type, ColumnSegmentType::TRANSIENT, 0U, function, BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0U, segment_size); } @@ -62,15 +62,13 @@ unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance // Construct/Destruct //===--------------------------------------------------------------------===// ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block_p, const LogicalType &type, - const ColumnSegmentType segment_type, const idx_t start, const idx_t count, - CompressionFunction &function_p, BaseStatistics statistics, const block_id_t block_id_p, - const idx_t offset, const idx_t segment_size_p, - const unique_ptr segment_state_p) + const ColumnSegmentType segment_type, const idx_t count, CompressionFunction &function_p, + BaseStatistics statistics, const block_id_t block_id_p, const idx_t offset, + const idx_t segment_size_p, const unique_ptr segment_state_p) - : SegmentBase(start, count), db(db), type(type), type_size(GetTypeIdSize(type.InternalType())), + : SegmentBase(count), db(db), type(type), type_size(GetTypeIdSize(type.InternalType())), segment_type(segment_type), stats(std::move(statistics)), block(std::move(block_p)), function(function_p), block_id(block_id_p), offset(offset), segment_size(segment_size_p) { - if (function.get().init_segment) { segment_state = function.get().init_segment(*this, block_id, segment_state_p.get()); } @@ -79,13 +77,11 @@ ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block D_ASSERT(!block || segment_size <= GetBlockManager().GetBlockSize()); } -ColumnSegment::ColumnSegment(ColumnSegment &other, const idx_t start) - - : SegmentBase(start, other.count.load()), db(other.db), type(std::move(other.type)), +ColumnSegment::ColumnSegment(ColumnSegment &other) + : SegmentBase(other.count.load()), db(other.db), type(std::move(other.type)), type_size(other.type_size), segment_type(other.segment_type), stats(std::move(other.stats)), block(std::move(other.block)), function(other.function), block_id(other.block_id), offset(other.offset), segment_size(other.segment_size), segment_state(std::move(other.segment_state)) { - // For constant segments (CompressionType::COMPRESSION_CONSTANT) the block is a nullptr. 
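// --- Editorial sketch (not part of the patch) --------------------------------
// With the per-segment `start` field removed from SegmentBase, a segment only
// knows how many rows it holds; absolute positions are resolved by the owning
// segment tree, and callers pass segment-relative indices (e.g.
// index_in_segment = offset_in_column - node.GetRowStart(), as the FetchRow
// callers below now do). A hypothetical resolution helper illustrating that
// arithmetic, not DuckDB API:
#include <cstddef>
#include <stdexcept>
#include <vector>

struct SketchSeg {
	std::size_t count; // rows in this segment; no absolute start is stored
};

struct SketchSegRef {
	std::size_t segment_index;    // which segment holds the row
	std::size_t index_in_segment; // row index relative to that segment
};

inline SketchSegRef SketchResolveRow(const std::vector<SketchSeg> &segments, std::size_t offset_in_column) {
	std::size_t row_start = 0; // running equivalent of GetRowStart()
	for (std::size_t i = 0; i < segments.size(); i++) {
		if (offset_in_column < row_start + segments[i].count) {
			return {i, offset_in_column - row_start};
		}
		row_start += segments[i].count;
	}
	throw std::out_of_range("offset_in_column past the end of the column");
}
// --- End editorial sketch -----------------------------------------------------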
D_ASSERT(!block || segment_size <= GetBlockManager().GetBlockSize()); } @@ -109,7 +105,7 @@ void ColumnSegment::InitializePrefetch(PrefetchState &prefetch_state, ColumnScan } void ColumnSegment::InitializeScan(ColumnScanState &state) { - state.scan_state = function.get().init_scan(*this); + state.scan_state = function.get().init_scan(state.context, *this); } void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, @@ -141,8 +137,8 @@ void ColumnSegment::Filter(ColumnScanState &state, idx_t scan_count, Vector &res } void ColumnSegment::Skip(ColumnScanState &state) { - function.get().skip(*this, state, state.row_index - state.internal_index); - state.internal_index = state.row_index; + function.get().skip(*this, state, state.offset_in_column - state.internal_index); + state.internal_index = state.offset_in_column; } void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result) { @@ -157,8 +153,10 @@ void ColumnSegment::ScanPartial(ColumnScanState &state, idx_t scan_count, Vector // Fetch //===--------------------------------------------------------------------===// void ColumnSegment::FetchRow(ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - function.get().fetch_row(*this, state, UnsafeNumericCast(UnsafeNumericCast(row_id) - this->start), - result, result_idx); + if (UnsafeNumericCast(row_id) > count) { + throw InternalException("ColumnSegment::FetchRow - row_id out of range for segment"); + } + function.get().fetch_row(*this, state, row_id, result, result_idx); } //===--------------------------------------------------------------------===// @@ -210,12 +208,12 @@ idx_t ColumnSegment::FinalizeAppend(ColumnAppendState &state) { return result_count; } -void ColumnSegment::RevertAppend(idx_t start_row) { +void ColumnSegment::RevertAppend(idx_t new_count) { D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); if (function.get().revert_append) { - function.get().revert_append(*this, start_row); + function.get().revert_append(*this, new_count); } - this->count = start_row - this->start; + this->count = new_count; } //===--------------------------------------------------------------------===// @@ -242,7 +240,9 @@ void ColumnSegment::ConvertToPersistent(QueryContext context, optional_ptr block_p, uint32_t offset_p) block = std::move(block_p); } -DataPointer ColumnSegment::GetDataPointer() { +DataPointer ColumnSegment::GetDataPointer(idx_t row_start) { if (segment_type != ColumnSegmentType::PERSISTENT) { throw InternalException("Attempting to call ColumnSegment::GetDataPointer on a transient segment"); } @@ -266,7 +266,7 @@ DataPointer ColumnSegment::GetDataPointer() { DataPointer pointer(stats.statistics.Copy()); pointer.block_pointer.block_id = GetBlockId(); pointer.block_pointer.offset = NumericCast(GetBlockOffset()); - pointer.row_start = start; + pointer.row_start = row_start; pointer.tuple_count = count; pointer.compression_type = function.get().type; if (function.get().serialize_state) { @@ -278,12 +278,12 @@ DataPointer ColumnSegment::GetDataPointer() { //===--------------------------------------------------------------------===// // Drop Segment //===--------------------------------------------------------------------===// -void ColumnSegment::CommitDropSegment() { +void ColumnSegment::VisitBlockIds(BlockIdVisitor &visitor) const { if (block_id != INVALID_BLOCK) { - GetBlockManager().MarkBlockAsModified(block_id); + visitor.Visit(block_id); } - if (function.get().cleanup_state) { - 
function.get().cleanup_state(*this); + if (function.get().visit_block_ids) { + function.get().visit_block_ids(*this, visitor); } } @@ -412,7 +412,8 @@ idx_t ColumnSegment::FilterSelection(SelectionVector &sel, Vector &vector, Unifi idx_t &approved_tuple_count) { switch (filter.filter_type) { case TableFilterType::OPTIONAL_FILTER: { - return scan_count; + auto &opt_filter = filter.Cast(); + return opt_filter.FilterSelection(sel, vector, vdata, filter_state, scan_count, approved_tuple_count); } case TableFilterType::CONJUNCTION_OR: { // similar to the CONJUNCTION_AND, but we need to take care of the SelectionVectors (OR all of them) @@ -559,6 +560,11 @@ idx_t ColumnSegment::FilterSelection(SelectionVector &sel, Vector &vector, Unifi return FilterSelection(sel, *child_vec, child_data, *struct_filter.child_filter, filter_state, scan_count, approved_tuple_count); } + case TableFilterType::BLOOM_FILTER: { + auto &bloom_filter = filter.Cast(); + auto &state = filter_state.Cast(); + return bloom_filter.Filter(vector, sel, approved_tuple_count, state); + } case TableFilterType::EXPRESSION_FILTER: { auto &state = filter_state.Cast(); SelectionVector result_sel(approved_tuple_count); diff --git a/src/duckdb/src/storage/table/in_memory_checkpoint.cpp b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp index 31a0170f3..254bff1e0 100644 --- a/src/duckdb/src/storage/table/in_memory_checkpoint.cpp +++ b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp @@ -9,10 +9,10 @@ namespace duckdb { // In-Memory Checkpoint Writer //===--------------------------------------------------------------------===// InMemoryCheckpointer::InMemoryCheckpointer(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, - StorageManager &storage_manager, CheckpointType checkpoint_type) + StorageManager &storage_manager, CheckpointOptions options_p) : CheckpointWriter(db), context(context.GetClientContext()), partial_block_manager(context, block_manager, PartialBlockType::IN_MEMORY_CHECKPOINT), - storage_manager(storage_manager), checkpoint_type(checkpoint_type) { + storage_manager(storage_manager), options(options_p) { } void InMemoryCheckpointer::CreateCheckpoint() { @@ -37,7 +37,7 @@ void InMemoryCheckpointer::CreateCheckpoint() { WriteTable(table, serializer); } - storage_manager.ResetInMemoryChange(); + storage_manager.SetWALSize(0); } MetadataWriter &InMemoryCheckpointer::GetMetadataWriter() { @@ -66,8 +66,8 @@ InMemoryRowGroupWriter::InMemoryRowGroupWriter(TableCatalogEntry &table, Partial : RowGroupWriter(table, partial_block_manager), checkpoint_manager(checkpoint_manager) { } -CheckpointType InMemoryRowGroupWriter::GetCheckpointType() const { - return checkpoint_manager.GetCheckpointType(); +CheckpointOptions InMemoryRowGroupWriter::GetCheckpointOptions() const { + return checkpoint_manager.GetCheckpointOptions(); } WriteStream &InMemoryRowGroupWriter::GetPayloadWriter() { @@ -98,8 +98,11 @@ unique_ptr InMemoryTableDataWriter::GetRowGroupWriter(RowGroup & return make_uniq(table, checkpoint_manager.GetPartialBlockManager(), checkpoint_manager); } -CheckpointType InMemoryTableDataWriter::GetCheckpointType() const { - return checkpoint_manager.GetCheckpointType(); +void InMemoryTableDataWriter::FlushPartialBlocks() { +} + +CheckpointOptions InMemoryTableDataWriter::GetCheckpointOptions() const { + return checkpoint_manager.GetCheckpointOptions(); } MetadataManager &InMemoryTableDataWriter::GetMetadataManager() { @@ -109,7 +112,7 @@ MetadataManager &InMemoryTableDataWriter::GetMetadataManager() { 
InMemoryPartialBlock::InMemoryPartialBlock(ColumnData &data, ColumnSegment &segment, PartialBlockState state, BlockManager &block_manager) : PartialBlock(state, block_manager, segment.block) { - AddSegmentToTail(data, segment, 0); + InMemoryPartialBlock::AddSegmentToTail(data, segment, 0); } InMemoryPartialBlock::~InMemoryPartialBlock() { diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index 7685d16ca..28f93ec8f 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -8,20 +8,22 @@ namespace duckdb { -ListColumnData::ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, - LogicalType type_p, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), - validity(block_manager, info, 0, start_row, *this) { +ListColumnData::ListColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type_p, + ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, std::move(type_p), data_type, parent) { D_ASSERT(type.InternalType() == PhysicalType::LIST); - auto &child_type = ListType::GetChildType(type); - // the child column, with column index 1 (0 is the validity mask) - child_column = ColumnData::CreateColumnUnique(block_manager, info, 1, start_row, child_type, this); + if (data_type != ColumnDataType::CHECKPOINT_TARGET) { + auto &child_type = ListType::GetChildType(type); + validity = make_shared_ptr(block_manager, info, 0, *this); + // the child column, with column index 1 (0 is the validity mask) + child_column = CreateColumn(block_manager, info, 1, child_type, data_type, this); + } } -void ListColumnData::SetStart(idx_t new_start) { - ColumnData::SetStart(new_start); - child_column->SetStart(new_start); - validity.SetStart(new_start); +void ListColumnData::SetDataType(ColumnDataType data_type) { + ColumnData::SetDataType(data_type); + child_column->SetDataType(data_type); + validity->SetDataType(data_type); } FilterPropagateResult ListColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { @@ -31,7 +33,7 @@ FilterPropagateResult ListColumnData::CheckZonemap(ColumnScanState &state, Table void ListColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { ColumnData::InitializePrefetch(prefetch_state, scan_state, rows); - validity.InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); + validity->InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); // we can't know how many rows we need to prefetch for the child of this list without looking at the actual data // we make an estimation by looking at how many rows the child column has versus this column @@ -48,7 +50,7 @@ void ListColumnData::InitializeScan(ColumnScanState &state) { // initialize the validity segment D_ASSERT(state.child_states.size() == 2); - validity.InitializeScan(state.child_states[0]); + validity->InitializeScan(state.child_states[0]); // initialize the child scan child_column->InitializeScan(state.child_states[1]); @@ -58,7 +60,8 @@ uint64_t ListColumnData::FetchListOffset(idx_t row_idx) { auto segment = data.GetSegment(row_idx); ColumnFetchState fetch_state; Vector result(LogicalType::UBIGINT, 1); - segment->FetchRow(fetch_state, UnsafeNumericCast(row_idx), result, 0U); + auto index_in_segment = UnsafeNumericCast(row_idx - 
segment->GetRowStart()); + segment->GetNode().FetchRow(fetch_state, index_in_segment, result, 0U); // initialize the child scan with the required offset return FlatVector::GetData(result)[0]; @@ -73,13 +76,13 @@ void ListColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_ // initialize the validity segment D_ASSERT(state.child_states.size() == 2); - validity.InitializeScanWithOffset(state.child_states[0], row_idx); + validity->InitializeScanWithOffset(state.child_states[0], row_idx); // we need to read the list at position row_idx to get the correct row offset of the child - auto child_offset = row_idx == start ? 0 : FetchListOffset(row_idx - 1); + auto child_offset = FetchListOffset(row_idx - 1); D_ASSERT(child_offset <= child_column->GetMaxEntry()); if (child_offset < child_column->GetMaxEntry()) { - child_column->InitializeScanWithOffset(state.child_states[1], start + child_offset); + child_column->InitializeScanWithOffset(state.child_states[1], child_offset); } state.last_offset = child_offset; } @@ -107,7 +110,7 @@ idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t co Vector offset_vector(LogicalType::UBIGINT, count); idx_t scan_count = ScanVector(state, offset_vector, count, ScanVectorType::SCAN_FLAT_VECTOR); D_ASSERT(scan_count > 0); - validity.ScanCount(state.child_states[0], result, count); + validity->ScanCount(state.child_states[0], result, count); UnifiedVectorFormat offsets; offset_vector.ToUnifiedFormat(scan_count, offsets); @@ -133,7 +136,7 @@ idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t co auto &child_entry = ListVector::GetEntry(result); if (child_entry.GetType().InternalType() != PhysicalType::STRUCT && child_entry.GetType().InternalType() != PhysicalType::ARRAY && - state.child_states[1].row_index + child_scan_count > child_column->start + child_column->GetMaxEntry()) { + state.child_states[1].offset_in_column + child_scan_count > child_column->GetMaxEntry()) { throw InternalException("ListColumnData::ScanCount - internal list scan offset is out of range"); } child_column->ScanCount(state.child_states[1], child_entry, child_scan_count); @@ -146,7 +149,7 @@ idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t co void ListColumnData::Skip(ColumnScanState &state, idx_t count) { // skip inside the validity segment - validity.Skip(state.child_states[0], count); + validity->Skip(state.child_states[0], count); // we need to read the list entries/offsets to figure out how much to skip // note that we only need to read the first and last entry @@ -175,7 +178,7 @@ void ListColumnData::InitializeAppend(ColumnAppendState &state) { // initialize the validity append ColumnAppendState validity_append_state; - validity.InitializeAppend(validity_append_state); + validity->InitializeAppend(validity_append_state); state.child_appends.push_back(std::move(validity_append_state)); // initialize the child column append @@ -245,14 +248,14 @@ void ListColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vec ColumnData::AppendData(stats, state, vdata, count); // append the validity data vdata.validity = append_mask; - validity.AppendData(stats, state.child_appends[0], vdata, count); + validity->AppendData(stats, state.child_appends[0], vdata, count); } -void ListColumnData::RevertAppend(row_t start_row) { - ColumnData::RevertAppend(start_row); - validity.RevertAppend(start_row); +void ListColumnData::RevertAppend(row_t new_count) { + ColumnData::RevertAppend(new_count); + 
validity->RevertAppend(new_count); auto column_count = GetMaxEntry(); - if (column_count > start) { + if (column_count > 0) { // revert append in the child column auto list_offset = FetchListOffset(column_count - 1); child_column->RevertAppend(UnsafeNumericCast(list_offset)); @@ -263,13 +266,14 @@ idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result throw NotImplementedException("List Fetch"); } -void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ListColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { throw NotImplementedException("List Update is not supported."); } -void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ListColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { throw NotImplementedException("List Update Column is not supported"); } @@ -289,17 +293,17 @@ void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &sta } // now perform the fetch within the segment - auto start_offset = idx_t(row_id) == this->start ? 0 : FetchListOffset(UnsafeNumericCast(row_id - 1)); + auto start_offset = row_id == 0 ? 0 : FetchListOffset(UnsafeNumericCast(row_id - 1)); auto end_offset = FetchListOffset(UnsafeNumericCast(row_id)); - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); - auto &validity = FlatVector::Validity(result); + auto &validity_mask = FlatVector::Validity(result); auto list_data = FlatVector::GetData(result); auto &list_entry = list_data[result_idx]; // set the list entry offset to the size of the current list list_entry.offset = ListVector::GetListSize(result); list_entry.length = end_offset - start_offset; - if (!validity.RowIsValid(result_idx)) { + if (!validity_mask.RowIsValid(result_idx)) { // the list is NULL! 
no need to fetch the child D_ASSERT(list_entry.length == 0); return; @@ -308,28 +312,45 @@ void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &sta // now we need to read from the child all the elements between [offset...length] auto child_scan_count = list_entry.length; if (child_scan_count > 0) { - auto child_state = make_uniq(); + ColumnScanState child_state(nullptr); auto &child_type = ListType::GetChildType(result.GetType()); Vector child_scan(child_type, child_scan_count); // seek the scan towards the specified position and read [length] entries - child_state->Initialize(child_type, nullptr); - child_column->InitializeScanWithOffset(*child_state, start + start_offset); + child_state.Initialize(state.context, child_type, nullptr); + child_column->InitializeScanWithOffset(child_state, start_offset); D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || - child_state->row_index + child_scan_count - this->start <= child_column->GetMaxEntry()); - child_column->ScanCount(*child_state, child_scan, child_scan_count); + child_state.offset_in_column + child_scan_count <= child_column->GetMaxEntry()); + child_column->ScanCount(child_state, child_scan, child_scan_count); ListVector::Append(result, child_scan, child_scan_count); } } -void ListColumnData::CommitDropColumn() { - ColumnData::CommitDropColumn(); - validity.CommitDropColumn(); - child_column->CommitDropColumn(); +void ListColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + ColumnData::VisitBlockIds(visitor); + validity->VisitBlockIds(visitor); + child_column->VisitBlockIds(visitor); +} + +void ListColumnData::SetValidityData(shared_ptr validity_p) { + if (validity) { + throw InternalException("ListColumnData::SetValidityData cannot be used to overwrite existing validity"); + } + validity_p->SetParent(this); + this->validity = std::move(validity_p); +} + +void ListColumnData::SetChildData(shared_ptr child_column_p) { + if (child_column) { + throw InternalException("ListColumnData::SetChildData cannot be used to overwrite existing data"); + } + child_column_p->SetParent(this); + this->child_column = std::move(child_column_p); } struct ListColumnCheckpointState : public ColumnCheckpointState { - ListColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) + ListColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) : ColumnCheckpointState(row_group, column_data, partial_block_manager) { global_stats = ListStats::CreateEmpty(column_data.type).ToUnique(); } @@ -338,8 +359,25 @@ struct ListColumnCheckpointState : public ColumnCheckpointState { unique_ptr child_state; public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, original_column.type, + ColumnDataType::CHECKPOINT_TARGET, nullptr); + } + + shared_ptr GetFinalResult() override { + if (result_column) { + auto &column_data = result_column->Cast(); + auto validity_child = validity_state->GetFinalResult(); + column_data.SetValidityData(shared_ptr_cast(std::move(validity_child))); + column_data.SetChildData(child_state->GetFinalResult()); + } + return ColumnCheckpointState::GetFinalResult(); + } + unique_ptr GetStatistics() override { auto stats = global_stats->Copy(); + stats.Merge(*validity_state->GetStatistics()); ListStats::SetChildStats(stats, child_state->GetStatistics()); return stats.ToUnique(); } @@ 
-352,15 +390,15 @@ struct ListColumnCheckpointState : public ColumnCheckpointState { } }; -unique_ptr ListColumnData::CreateCheckpointState(RowGroup &row_group, +unique_ptr ListColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { return make_uniq(row_group, *this, partial_block_manager); } -unique_ptr ListColumnData::Checkpoint(RowGroup &row_group, +unique_ptr ListColumnData::Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { auto base_state = ColumnData::Checkpoint(row_group, checkpoint_info); - auto validity_state = validity.Checkpoint(row_group, checkpoint_info); + auto validity_state = validity->Checkpoint(row_group, checkpoint_info); auto child_state = child_column->Checkpoint(row_group, checkpoint_info); auto &checkpoint_state = base_state->Cast(); @@ -370,34 +408,34 @@ unique_ptr ListColumnData::Checkpoint(RowGroup &row_group } bool ListColumnData::IsPersistent() { - return ColumnData::IsPersistent() && validity.IsPersistent() && child_column->IsPersistent(); + return ColumnData::IsPersistent() && validity->IsPersistent() && child_column->IsPersistent(); } bool ListColumnData::HasAnyChanges() const { - return ColumnData::HasAnyChanges() || validity.HasAnyChanges() || child_column->HasAnyChanges(); + return ColumnData::HasAnyChanges() || validity->HasAnyChanges() || child_column->HasAnyChanges(); } PersistentColumnData ListColumnData::Serialize() { auto persistent_data = ColumnData::Serialize(); - persistent_data.child_columns.push_back(validity.Serialize()); + persistent_data.child_columns.push_back(validity->Serialize()); persistent_data.child_columns.push_back(child_column->Serialize()); return persistent_data; } void ListColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) { ColumnData::InitializeColumn(column_data, target_stats); - validity.InitializeColumn(column_data.child_columns[0], target_stats); + validity->InitializeColumn(column_data.child_columns[0], target_stats); auto &child_stats = ListStats::GetChildStats(target_stats); child_column->InitializeColumn(column_data.child_columns[1], child_stats); } -void ListColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); +void ListColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index 40e20d2d4..f75903635 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/typedefs.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/execution/adaptive_filter.hpp" #include "duckdb/execution/expression_executor.hpp" @@ -11,6 +12,8 @@ 
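// --- Editor's note: illustrative sketch, not part of the vendored diff. ---
// A recurring theme in the list_column_data.cpp hunks above (and the row_group.cpp hunks that
// follow) is that ColumnData/RowGroup no longer carry an absolute `start`: offsets inside a row
// group are 0-based, and the physical position comes from the owning segment node (GetRowStart()
// in the diff). A caller holding an absolute row id would translate it before using the
// per-row-group API; the helper name below is an assumption.
static idx_t ToRowGroupLocalIndex(idx_t row_group_start, row_t absolute_row_id) {
	// e.g. RowGroup::Delete now computes ids[i] - row_group_start instead of ids[i] - this->start
	return UnsafeNumericCast<idx_t>(absolute_row_id) - row_group_start;
}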
#include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/checkpoint/table_data_writer.hpp" #include "duckdb/storage/metadata/metadata_reader.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" #include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/column_data.hpp" @@ -25,15 +28,15 @@ namespace duckdb { -RowGroup::RowGroup(RowGroupCollection &collection_p, idx_t start, idx_t count) - : SegmentBase(start, count), collection(collection_p), version_info(nullptr), allocation_size(0), - row_id_is_loaded(false), has_changes(false) { +RowGroup::RowGroup(RowGroupCollection &collection_p, idx_t count) + : SegmentBase(count), collection(collection_p), version_info(nullptr), deletes_is_loaded(false), + allocation_size(0), row_id_is_loaded(false), has_changes(false) { Verify(); } RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) - : SegmentBase(pointer.row_start, pointer.tuple_count), collection(collection_p), version_info(nullptr), - allocation_size(0), row_id_is_loaded(false), has_changes(false) { + : SegmentBase(pointer.tuple_count), collection(collection_p), version_info(nullptr), + deletes_is_loaded(false), allocation_size(0), row_id_is_loaded(false), has_changes(false) { // deserialize the columns if (pointer.data_pointers.size() != collection_p.GetTypes().size()) { throw IOException("Row group column count is unaligned with table column count. Corrupt file?"); @@ -45,7 +48,6 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) this->is_loaded[c] = false; } this->deletes_pointers = std::move(pointer.deletes_pointers); - this->deletes_is_loaded = false; this->has_metadata_blocks = pointer.has_metadata_blocks; this->extra_metadata_blocks = std::move(pointer.extra_metadata_blocks); @@ -53,14 +55,14 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) } RowGroup::RowGroup(RowGroupCollection &collection_p, PersistentRowGroupData &data) - : SegmentBase(data.start, data.count), collection(collection_p), version_info(nullptr), + : SegmentBase(data.count), collection(collection_p), version_info(nullptr), deletes_is_loaded(false), allocation_size(0), row_id_is_loaded(false), has_changes(false) { auto &block_manager = GetBlockManager(); auto &info = GetTableInfo(); auto &types = collection.get().GetTypes(); columns.reserve(types.size()); for (idx_t c = 0; c < types.size(); c++) { - auto entry = ColumnData::CreateColumn(block_manager, info, c, data.start, types[c], nullptr); + auto entry = ColumnData::CreateColumn(block_manager, info, c, types[c]); entry->InitializeColumn(data.column_data[c]); columns.push_back(std::move(entry)); } @@ -68,29 +70,23 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, PersistentRowGroupData &dat Verify(); } -void RowGroup::MoveToCollection(RowGroupCollection &collection_p, idx_t new_start) { +void RowGroup::MoveToCollection(RowGroupCollection &collection_p) { lock_guard l(row_group_lock); - if (start != new_start) { - has_changes = true; - } + // FIXME + // MoveToCollection causes any_changes to be set to true because we are changing the start position of the row group + // the start position is ONLY written when targeting old serialization versions - as such, we don't actually + // need to do this when targeting newer serialization versions + // not doing this could allow metadata reuse in these situations, which would improve 
vacuuming performance + // especially when vacuuming from the beginning of large tables + has_changes = true; this->collection = collection_p; - this->start = new_start; for (idx_t c = 0; c < columns.size(); c++) { if (is_loaded && !is_loaded[c]) { // we only need to set the column start position if it is already loaded // if it is not loaded - we will set the correct start position upon loading continue; } - columns[c]->SetStart(new_start); - } - if (row_id_is_loaded) { - row_id_column_data->SetStart(new_start); - } - if (!HasUnloadedDeletes()) { - auto vinfo = GetVersionInfo(); - if (vinfo) { - vinfo->SetStart(new_start); - } + columns[c]->SetDataType(ColumnDataType::MAIN_TABLE); } } @@ -113,42 +109,50 @@ idx_t RowGroup::GetRowGroupSize() const { return collection.get().GetRowGroupSize(); } -ColumnData &RowGroup::GetRowIdColumnData() { +void RowGroup::LoadRowIdColumnData() const { if (row_id_is_loaded) { - return *row_id_column_data; + return; } lock_guard l(row_group_lock); - if (!row_id_column_data) { - row_id_column_data = make_uniq(GetBlockManager(), GetTableInfo(), start); - row_id_column_data->count = count.load(); - row_id_is_loaded = true; + if (row_id_column_data) { + return; } - return *row_id_column_data; + row_id_column_data = make_uniq(GetBlockManager(), GetTableInfo()); + row_id_column_data->count = count.load(); + row_id_is_loaded = true; } -ColumnData &RowGroup::GetColumn(const StorageIndex &c) { +ColumnData &RowGroup::GetColumn(const StorageIndex &c) const { return GetColumn(c.GetPrimaryIndex()); } -ColumnData &RowGroup::GetColumn(storage_t c) { +ColumnData &RowGroup::GetColumn(storage_t c) const { + LoadColumn(c); + return c == COLUMN_IDENTIFIER_ROW_ID ? *row_id_column_data : *columns[c]; +} + +void RowGroup::LoadColumn(storage_t c) const { if (c == COLUMN_IDENTIFIER_ROW_ID) { - return GetRowIdColumnData(); + LoadRowIdColumnData(); + return; } D_ASSERT(c < columns.size()); if (!is_loaded) { // not being lazy loaded D_ASSERT(columns[c]); - return *columns[c]; + return; } if (is_loaded[c]) { D_ASSERT(columns[c]); - return *columns[c]; + return; } lock_guard l(row_group_lock); if (columns[c]) { + // another thread loaded the column while we were waiting for the lock D_ASSERT(is_loaded[c]); - return *columns[c]; + return; } + // load the column if (column_pointers.size() != columns.size()) { throw InternalException("Lazy loading a column but the pointer was not set"); } @@ -156,52 +160,63 @@ ColumnData &RowGroup::GetColumn(storage_t c) { auto &types = GetCollection().GetTypes(); auto &block_pointer = column_pointers[c]; MetadataReader column_data_reader(metadata_manager, block_pointer); - this->columns[c] = - ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), c, start, column_data_reader, types[c]); + this->columns[c] = ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), c, column_data_reader, types[c]); is_loaded[c] = true; if (this->columns[c]->count != this->count) { - throw InternalException("Corrupted database - loaded column with index %llu at row start %llu, count %llu did " + throw InternalException("Corrupted database - loaded column with index %llu, count %llu did " "not match count of row group %llu", - c, start, this->columns[c]->count.load(), this->count.load()); + c, this->columns[c]->count.load(), this->count.load()); } - return *columns[c]; } -BlockManager &RowGroup::GetBlockManager() { +BlockManager &RowGroup::GetBlockManager() const { return GetCollection().GetBlockManager(); } -DataTableInfo &RowGroup::GetTableInfo() { +DataTableInfo 
&RowGroup::GetTableInfo() const { return GetCollection().GetTableInfo(); } -void RowGroup::InitializeEmpty(const vector &types) { +void RowGroup::InitializeEmpty(const vector &types, ColumnDataType data_type) { // set up the segment trees for the column segments D_ASSERT(columns.empty()); for (idx_t i = 0; i < types.size(); i++) { - auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), i, start, types[i]); + auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), i, types[i], data_type); columns.push_back(std::move(column_data)); } } -void ColumnScanState::Initialize(const LogicalType &type, const vector &children, - optional_ptr options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + const vector &children, optional_ptr options) { // Register the options in the state scan_options = options; + context = context_p; if (type.id() == LogicalTypeId::VALIDITY) { // validity - nothing to initialize return; } + + if (type.id() == LogicalTypeId::VARIANT) { + // variant - column scan states are created later + // this is done because the internal shape of the VARIANT is different per rowgroup + scan_child_column.resize(2, true); + return; + } + + D_ASSERT(child_states.empty()); if (type.InternalType() == PhysicalType::STRUCT) { // validity + struct children auto &struct_children = StructType::GetChildTypes(type); - child_states.resize(struct_children.size() + 1); + child_states.reserve(struct_children.size() + 1); + for (idx_t i = 0; i <= struct_children.size(); i++) { + child_states.emplace_back(parent); + } if (children.empty()) { // scan all struct children scan_child_column.resize(struct_children.size(), true); for (idx_t i = 0; i < struct_children.size(); i++) { - child_states[i + 1].Initialize(struct_children[i].second, options); + child_states[i + 1].Initialize(context, struct_children[i].second, options); } } else { // only scan the specified subset of columns @@ -211,61 +226,73 @@ void ColumnScanState::Initialize(const LogicalType &type, const vector options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + optional_ptr options) { vector children; - Initialize(type, children, options); + Initialize(context_p, type, children, options); } -void CollectionScanState::Initialize(const vector &types) { +void CollectionScanState::Initialize(const QueryContext &context, const vector &types) { auto &column_ids = GetColumnIds(); - column_scans = make_unsafe_uniq_array(column_ids.size()); + D_ASSERT(column_scans.empty()); + column_scans.reserve(column_scans.size()); + for (idx_t i = 0; i < column_ids.size(); i++) { + column_scans.emplace_back(*this); + } for (idx_t i = 0; i < column_ids.size(); i++) { if (column_ids[i].IsRowIdColumn()) { continue; } auto col_id = column_ids[i].GetPrimaryIndex(); - column_scans[i].Initialize(types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); + column_scans[i].Initialize(context, types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); } } -bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector_offset) { +bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, SegmentNode &node, idx_t vector_offset) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); if (!CheckZonemap(filters)) { return false; } + if (!RefersToSameObject(node.GetNode(), *this)) { + throw InternalException("RowGroup::InitializeScanWithOffset segment node mismatch"); + } - 
state.row_group = this; + state.row_group = node; state.vector_index = vector_offset; - state.max_row_group_row = - this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); - auto row_number = start + vector_offset * STANDARD_VECTOR_SIZE; + auto row_start = node.GetRowStart(); + state.max_row_group_row = row_start > state.max_row ? 0 : MinValue(this->count, state.max_row - row_start); + auto row_number = vector_offset * STANDARD_VECTOR_SIZE; if (state.max_row_group_row == 0) { // exceeded row groups to scan return false; } - D_ASSERT(state.column_scans); + D_ASSERT(!state.column_scans.empty()); for (idx_t i = 0; i < column_ids.size(); i++) { const auto &column = column_ids[i]; auto &column_data = GetColumn(column); @@ -275,20 +302,23 @@ bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector return true; } -bool RowGroup::InitializeScan(CollectionScanState &state) { +bool RowGroup::InitializeScan(CollectionScanState &state, SegmentNode &node) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); if (!CheckZonemap(filters)) { return false; } - state.row_group = this; + if (!RefersToSameObject(node.GetNode(), *this)) { + throw InternalException("RowGroup::InitializeScan segment node mismatch"); + } + auto row_start = node.GetRowStart(); + state.row_group = node; state.vector_index = 0; - state.max_row_group_row = - this->start > state.max_row ? 0 : MinValue(this->count, state.max_row - this->start); + state.max_row_group_row = row_start > state.max_row ? 0 : MinValue(this->count, state.max_row - row_start); if (state.max_row_group_row == 0) { return false; } - D_ASSERT(state.column_scans); + D_ASSERT(!state.column_scans.empty()); for (idx_t i = 0; i < column_ids.size(); i++) { auto column = column_ids[i]; auto &column_data = GetColumn(column); @@ -300,18 +330,18 @@ bool RowGroup::InitializeScan(CollectionScanState &state) { unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, const LogicalType &target_type, idx_t changed_idx, ExpressionExecutor &executor, - CollectionScanState &scan_state, DataChunk &scan_chunk) { + CollectionScanState &scan_state, SegmentNode &node, + DataChunk &scan_chunk) { Verify(); // construct a new column data for this type - auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), changed_idx, start, target_type); + auto column_data = ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), changed_idx, target_type); ColumnAppendState append_state; column_data->InitializeAppend(append_state); // scan the original table, and fill the new column with the transformed value - scan_state.Initialize(GetCollection().GetTypes()); - InitializeScan(scan_state); + InitializeScan(scan_state, node); DataChunk append_chunk; vector append_types; @@ -332,7 +362,7 @@ unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, con } // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->start, this->count); + auto row_group = make_uniq(new_collection, this->count); row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); auto &cols = GetColumns(); for (idx_t i = 0; i < cols.size(); i++) { @@ -355,7 +385,7 @@ unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, Col // construct a new column data for the new column auto added_column = - ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), GetColumnCount(), start, new_column.Type()); + ColumnData::CreateColumn(GetBlockManager(), GetTableInfo(), 
GetColumnCount(), new_column.Type()); idx_t rows_to_write = this->count; if (rows_to_write > 0) { @@ -372,7 +402,7 @@ unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, Col } // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->start, this->count); + auto row_group = make_uniq(new_collection, this->count); row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); row_group->columns = GetColumns(); // now add the new column @@ -387,7 +417,7 @@ unique_ptr RowGroup::RemoveColumn(RowGroupCollection &new_collection, D_ASSERT(removed_column < columns.size()); - auto row_group = make_uniq(new_collection, this->start, this->count); + auto row_group = make_uniq(new_collection, this->count); row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); // copy over all columns except for the removed one auto &cols = GetColumns(); @@ -407,9 +437,21 @@ void RowGroup::CommitDrop() { } } +struct BlockIdDropper : public BlockIdVisitor { + explicit BlockIdDropper(BlockManager &manager) : manager(manager) { + } + + void Visit(block_id_t block_id) override { + manager.MarkBlockAsModified(block_id); + } + + BlockManager &manager; +}; + void RowGroup::CommitDropColumn(const idx_t column_index) { auto &column = GetColumn(column_index); - column.CommitDropColumn(); + BlockIdDropper dropper(GetBlockManager()); + column.VisitBlockIds(dropper); } void RowGroup::NextVector(CollectionScanState &state) { @@ -424,6 +466,7 @@ void RowGroup::NextVector(CollectionScanState &state) { FilterPropagateResult RowGroup::CheckRowIdFilter(const TableFilter &filter, idx_t beg_row, idx_t end_row) { // RowId columns dont have a zonemap, but we can trivially create stats to check the filter against. BaseStatistics dummy_stats = NumericStats::CreateEmpty(LogicalType::ROW_TYPE); + dummy_stats.SetHasNoNullFast(); NumericStats::SetMin(dummy_stats, UnsafeNumericCast(beg_row)); NumericStats::SetMax(dummy_stats, UnsafeNumericCast(end_row)); @@ -457,6 +500,7 @@ bool RowGroup::CheckZonemap(ScanFilterInfo &filters) { bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { auto &filters = state.GetFilterInfo(); + optional_idx target_vector_index_max; for (auto &entry : filters.GetFilterList()) { if (entry.IsAlwaysTrue()) { // filter is always true - avoid checking @@ -478,15 +522,21 @@ bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { // no segment to skip continue; } - idx_t target_row = current_segment->start + current_segment->count; + auto row_start = current_segment->GetRowStart(); + idx_t target_row = row_start + current_segment->GetNode().count; if (target_row >= state.max_row) { target_row = state.max_row; } + D_ASSERT(target_row >= row_start); + D_ASSERT(target_row <= row_start + this->count); + idx_t target_vector_index = (target_row - row_start) / STANDARD_VECTOR_SIZE; - D_ASSERT(target_row >= this->start); - D_ASSERT(target_row <= this->start + this->count); - idx_t target_vector_index = (target_row - this->start) / STANDARD_VECTOR_SIZE; - if (state.vector_index == target_vector_index) { + if (!target_vector_index_max.IsValid() || target_vector_index_max.GetIndex() < target_vector_index) { + target_vector_index_max = target_vector_index; + } + } + if (target_vector_index_max.IsValid()) { + if (state.vector_index == target_vector_index_max.GetIndex()) { // we can't skip any full vectors because this segment contains less than a full vector // for now we just bail-out // FIXME: we could check if we can ALSO skip the next segments, in which case skipping a 
full vector @@ -495,13 +545,13 @@ bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { // exceedingly rare return true; } - while (state.vector_index < target_vector_index) { + while (state.vector_index < target_vector_index_max.GetIndex()) { NextVector(state); } return false; + } else { + return true; } - - return true; } template @@ -529,19 +579,20 @@ void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &s if (!CheckZonemapSegments(state)) { continue; } + auto ¤t_row_group = state.row_group->GetNode(); // second, scan the version chunk manager to figure out which tuples to load for this transaction idx_t count; if (TYPE == TableScanType::TABLE_SCAN_REGULAR) { - count = state.row_group->GetSelVector(transaction, state.vector_index, state.valid_sel, max_count); + count = current_row_group.GetSelVector(transaction, state.vector_index, state.valid_sel, max_count); if (count == 0) { // nothing to scan for this vector, skip the entire vector NextVector(state); continue; } } else if (TYPE == TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED) { - count = state.row_group->GetCommittedSelVector(transaction.start_time, transaction.transaction_id, - state.vector_index, state.valid_sel, max_count); + count = current_row_group.GetCommittedSelVector(transaction.start_time, transaction.transaction_id, + state.vector_index, state.valid_sel, max_count); if (count == 0) { // nothing to scan for this vector, skip the entire vector NextVector(state); @@ -706,7 +757,7 @@ optional_ptr RowGroup::GetVersionInfo() { } // deletes are not loaded - reload auto root_delete = deletes_pointers[0]; - auto loaded_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager(), start); + auto loaded_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager()); SetVersionInfo(std::move(loaded_info)); deletes_is_loaded = true; return version_info; @@ -721,7 +772,8 @@ shared_ptr RowGroup::GetOrCreateVersionInfoInternal() { // version info does not exist - need to create it lock_guard lock(row_group_lock); if (!owned_version_info) { - auto new_info = make_shared_ptr(start); + auto &buffer_manager = GetBlockManager().GetBufferManager(); + auto new_info = make_shared_ptr(buffer_manager); SetVersionInfo(std::move(new_info)); } return owned_version_info; @@ -745,6 +797,27 @@ RowVersionManager &RowGroup::GetOrCreateVersionInfo() { return *GetOrCreateVersionInfoInternal(); } +optional_ptr RowGroup::GetVersionInfoIfLoaded() const { + if (!HasUnloadedDeletes()) { + // deletes are loaded - return the version info + return version_info; + } + return nullptr; +} + +bool RowGroup::ShouldCheckpointRowGroup(transaction_t checkpoint_id) const { + if (checkpoint_id == MAX_TRANSACTION_ID) { + // no id specified - checkpoint all committed data + return true; + } + // check if this row group was committed as of the current checkpoint id + auto vinfo = GetVersionInfoIfLoaded(); + if (!vinfo) { + return true; + } + return vinfo->ShouldCheckpointRowGroup(checkpoint_id, count); +} + idx_t RowGroup::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count) { auto vinfo = GetVersionInfo(); @@ -764,7 +837,9 @@ idx_t RowGroup::GetCommittedSelVector(transaction_t start_time, transaction_t tr } bool RowGroup::Fetch(TransactionData transaction, idx_t row) { - D_ASSERT(row < this->count); + if (UnsafeNumericCast(row) > count) { + throw InternalException("RowGroup::Fetch - row_id out of range for row 
group"); + } auto vinfo = GetVersionInfo(); if (!vinfo) { return true; @@ -774,6 +849,9 @@ bool RowGroup::Fetch(TransactionData transaction, idx_t row) { void RowGroup::FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, row_t row_id, DataChunk &result, idx_t result_idx) { + if (UnsafeNumericCast(row_id) > count) { + throw InternalException("RowGroup::FetchRow - row_id out of range for row group"); + } for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { auto &column = column_ids[col_idx]; auto &result_vector = result.data[col_idx]; @@ -814,13 +892,16 @@ void RowGroup::CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_ vinfo.CommitAppend(commit_id, row_group_start, count); } -void RowGroup::RevertAppend(idx_t row_group_start) { +void RowGroup::RevertAppend(idx_t new_count) { + if (new_count > this->count) { + throw InternalException("RowGroup::RevertAppend new_count out of range"); + } auto &vinfo = GetOrCreateVersionInfo(); - vinfo.RevertAppend(row_group_start - this->start); + vinfo.RevertAppend(new_count); for (auto &column : GetColumns()) { - column->RevertAppend(UnsafeNumericCast(row_group_start)); + column->RevertAppend(UnsafeNumericCast(new_count)); } - SetCount(MinValue(row_group_start - this->start, this->count)); + SetCount(new_count); Verify(); } @@ -852,11 +933,11 @@ void RowGroup::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_ vinfo.CleanupAppend(lowest_transaction, start, count); } -void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids) { +void RowGroup::Update(TransactionData transaction, DataTable &data_table, DataChunk &update_chunk, row_t *ids, + idx_t offset, idx_t count, const vector &column_ids, idx_t row_group_start) { #ifdef DEBUG for (size_t i = offset; i < offset + count; i++) { - D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); + D_ASSERT(ids[i] >= row_t(row_group_start) && ids[i] < row_t(row_group_start + this->count)); } #endif for (idx_t i = 0; i < column_ids.size(); i++) { @@ -866,33 +947,36 @@ void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_ if (offset > 0) { Vector sliced_vector(update_chunk.data[i], offset, offset + count); sliced_vector.Flatten(count); - col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); + col_data.Update(transaction, data_table, column.index, sliced_vector, ids + offset, count, row_group_start); } else { - col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); + col_data.Update(transaction, data_table, column.index, update_chunk.data[i], ids, count, row_group_start); } MergeStatistics(column.index, *col_data.GetUpdateStatistics()); } } -void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path) { +void RowGroup::UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path, idx_t row_group_start) { D_ASSERT(updates.ColumnCount() == 1); auto ids = FlatVector::GetData(row_ids); auto primary_column_idx = column_path[0]; D_ASSERT(primary_column_idx < columns.size()); auto &col_data = GetColumn(primary_column_idx); + idx_t depth = 1; if (offset > 0) { Vector sliced_vector(updates.data[0], offset, offset + count); sliced_vector.Flatten(count); - 
col_data.UpdateColumn(transaction, column_path, sliced_vector, ids + offset, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, sliced_vector, ids + offset, count, depth, + row_group_start); } else { - col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, updates.data[0], ids, count, depth, + row_group_start); } MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); } -unique_ptr RowGroup::GetStatistics(idx_t column_idx) { +unique_ptr RowGroup::GetStatistics(idx_t column_idx) const { auto &col_data = GetColumn(column_idx); return col_data.GetStatistics(); } @@ -914,12 +998,38 @@ void RowGroup::MergeIntoStatistics(TableStatistics &other) { } } +ColumnCheckpointInfo::ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) + : column_idx(column_idx), info(info) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + CheckpointOptions options_p) + : manager(manager), compression_types(compression_types), options(options_p) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p) + : manager(manager), compression_types(compression_types), options(), + column_partial_block_managers(column_partial_block_managers_p) { +} + +PartialBlockManager &RowGroupWriteInfo::GetPartialBlockManager(idx_t column_idx) { + if (column_partial_block_managers && !column_partial_block_managers->empty()) { + return *column_partial_block_managers->at(column_idx); + } + return manager; +} + +PartialBlockManager &ColumnCheckpointInfo::GetPartialBlockManager() { + return info.GetPartialBlockManager(column_idx); +} + CompressionType ColumnCheckpointInfo::GetCompressionType() { return info.compression_types[column_idx]; } vector RowGroup::WriteToDisk(RowGroupWriteInfo &info, - const vector> &row_groups) { + const vector> &row_groups) { vector result; if (row_groups.empty()) { return result; @@ -931,10 +1041,12 @@ vector RowGroup::WriteToDisk(RowGroupWriteInfo &info, RowGroupWriteData write_data; write_data.states.reserve(column_count); write_data.statistics.reserve(column_count); + write_data.should_checkpoint = row_group.get().ShouldCheckpointRowGroup(info.options.transaction_id); result.push_back(std::move(write_data)); } // Checkpoint the row groups + // In order to co-locate columns across different row groups, we write column-at-a-time // i.e. we first write column #0 of all row groups, then column #1, ... @@ -945,30 +1057,53 @@ vector RowGroup::WriteToDisk(RowGroupWriteInfo &info, // Some of these columns are composite (list, struct). The data is written // first sequentially, and the pointers are written later, so that the // pointers all end up densely packed, and thus more cache-friendly. 
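// --- Editor's note: hypothetical usage sketch, not part of the vendored diff. ---
// The batched WriteToDisk overload below writes column #0 of every row group, then column #1,
// and so on, so that each column's blocks end up co-located across row groups. A checkpoint
// writer could drive it roughly like this; the helper name and the reference<RowGroup> element
// type are assumptions (the template arguments are elided in this diff), while the writer
// accessors (GetPartialBlockManager / GetCompressionTypes / GetCheckpointOptions) and the static
// RowGroup::WriteToDisk(info, row_groups) entry point do appear in the diff.
static vector<RowGroupWriteData> CheckpointRowGroupBatch(RowGroupWriter &writer,
                                                         const vector<reference<RowGroup>> &row_groups_to_write) {
	RowGroupWriteInfo info(writer.GetPartialBlockManager(), writer.GetCompressionTypes(),
	                       writer.GetCheckpointOptions());
	// one RowGroupWriteData entry per input row group; row groups whose should_checkpoint
	// flag is false keep their entry but are not re-written
	return RowGroup::WriteToDisk(info, row_groups_to_write);
}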
+ vector>> result_columns; + result_columns.resize(row_groups.size()); for (idx_t column_idx = 0; column_idx < column_count; column_idx++) { for (idx_t row_group_idx = 0; row_group_idx < row_groups.size(); row_group_idx++) { auto &row_group = row_groups[row_group_idx].get(); auto &row_group_write_data = result[row_group_idx]; - auto &column = row_group.GetColumn(column_idx); - if (column.start != row_group.start) { - throw InternalException("RowGroup::WriteToDisk - child-column is unaligned with row group"); + if (!row_group_write_data.should_checkpoint) { + // row group should not be checkpointed - skip + continue; } + auto &column = row_group.GetColumn(column_idx); ColumnCheckpointInfo checkpoint_info(info, column_idx); auto checkpoint_state = column.Checkpoint(row_group, checkpoint_info); - D_ASSERT(checkpoint_state); + auto result_col = checkpoint_state->GetFinalResult(); + // FIXME: we should get rid of the checkpoint state statistics - and instead use the stats in the ColumnData + // directly auto stats = checkpoint_state->GetStatistics(); - D_ASSERT(stats); + result_col->MergeStatistics(*stats); + result_columns[row_group_idx].push_back(std::move(result_col)); row_group_write_data.statistics.push_back(stats->Copy()); row_group_write_data.states.push_back(std::move(checkpoint_state)); } } + + // create the row groups + for (idx_t row_group_idx = 0; row_group_idx < row_groups.size(); row_group_idx++) { + auto &row_group_write_data = result[row_group_idx]; + auto &row_group = row_groups[row_group_idx].get(); + if (!row_group_write_data.should_checkpoint) { + // row group should not be checkpointed - skip + continue; + } + auto result_row_group = make_shared_ptr(row_group.GetCollection(), row_group.count); + result_row_group->columns = std::move(result_columns[row_group_idx]); + result_row_group->version_info = row_group.version_info.load(); + result_row_group->owned_version_info = row_group.owned_version_info; + + row_group_write_data.result_row_group = std::move(result_row_group); + } + return result; } -RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriteInfo &info) { - vector> row_groups; +RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriteInfo &info) const { + vector> row_groups; row_groups.push_back(*this); auto result = WriteToDisk(info, row_groups); return std::move(result[0]); @@ -991,21 +1126,15 @@ bool RowGroup::HasUnloadedDeletes() const { return !deletes_is_loaded; } -vector RowGroup::GetColumnPointers() { - if (has_metadata_blocks) { - // we have the column metadata from the file itself - no need to deserialize metadata to fetch it - // read if from "column_pointers" and "extra_metadata_blocks" - auto result = column_pointers; - for (auto &block_pointer : extra_metadata_blocks) { - result.emplace_back(block_pointer, 0); - } - return result; +vector RowGroup::GetOrComputeExtraMetadataBlocks(bool force_compute) { + if (has_metadata_blocks && !force_compute) { + return extra_metadata_blocks; } - vector result; if (column_pointers.empty()) { // no pointers - return result; + return {}; } + vector read_pointers; // column_pointers stores the beginning of each column // if columns are big - they may span multiple metadata blocks // we need to figure out all blocks that this row group points to @@ -1016,13 +1145,25 @@ vector RowGroup::GetColumnPointers() { // for all but the last column pointer - we can just follow the linked list until we reach the last column MetadataReader reader(metadata_manager, column_pointers[0]); auto last_pointer = column_pointers[last_idx]; - result 
= reader.GetRemainingBlocks(last_pointer); + read_pointers = reader.GetRemainingBlocks(last_pointer); } // for the last column we need to deserialize the column - because we don't know where it stops auto &types = GetCollection().GetTypes(); - MetadataReader reader(metadata_manager, column_pointers[last_idx], &result); - ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), last_idx, start, reader, types[last_idx]); - return result; + MetadataReader reader(metadata_manager, column_pointers[last_idx], &read_pointers); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), last_idx, reader, types[last_idx]); + + unordered_set result_as_set; + for (auto &ptr : read_pointers) { + result_as_set.emplace(ptr.block_pointer); + } + for (auto &ptr : column_pointers) { + result_as_set.erase(ptr.block_pointer); + } + return {result_as_set.begin(), result_as_set.end()}; +} + +const vector &RowGroup::GetColumnStartPointers() const { + return column_pointers; } RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { @@ -1031,7 +1172,8 @@ RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { // we have existing metadata and the row group has not been changed // re-use previous metadata RowGroupWriteData result; - result.existing_pointers = GetColumnPointers(); + result.reuse_existing_metadata_blocks = true; + result.existing_extra_metadata_blocks = GetOrComputeExtraMetadataBlocks(); return result; } auto &compression_types = writer.GetCompressionTypes(); @@ -1046,27 +1188,52 @@ RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { column_idx, this->count.load(), column.count.load()); } } - - RowGroupWriteInfo info(writer.GetPartialBlockManager(), compression_types, writer.GetCheckpointType()); + RowGroupWriteInfo info(writer.GetPartialBlockManager(), compression_types, writer.GetCheckpointOptions()); return WriteToDisk(info); } +void IncrementSegmentStart(PersistentColumnData &data, idx_t start_increment) { + for (auto &pointer : data.pointers) { + pointer.row_start += start_increment; + } + for (auto &child_column : data.child_columns) { + IncrementSegmentStart(child_column, start_increment); + } +} + RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWriter &writer, - TableStatistics &global_stats) { + TableStatistics &global_stats, idx_t row_group_start) { RowGroupPointer row_group_pointer; auto metadata_manager = writer.GetMetadataManager(); // construct the row group pointer and write the column meta data to disk - row_group_pointer.row_start = start; + row_group_pointer.row_start = row_group_start; row_group_pointer.tuple_count = count; - if (!write_data.existing_pointers.empty()) { + if (write_data.reuse_existing_metadata_blocks) { // we are re-using the previous metadata row_group_pointer.data_pointers = column_pointers; - row_group_pointer.has_metadata_blocks = has_metadata_blocks; - row_group_pointer.extra_metadata_blocks = extra_metadata_blocks; - row_group_pointer.deletes_pointers = deletes_pointers; - metadata_manager->ClearModifiedBlocks(write_data.existing_pointers); - metadata_manager->ClearModifiedBlocks(deletes_pointers); + row_group_pointer.has_metadata_blocks = true; + row_group_pointer.extra_metadata_blocks = write_data.existing_extra_metadata_blocks; + row_group_pointer.deletes_pointers = CheckpointDeletes(*metadata_manager); + if (metadata_manager) { + vector extra_metadata_block_pointers; + extra_metadata_block_pointers.reserve(write_data.existing_extra_metadata_blocks.size()); + for (auto &block_pointer : 
write_data.existing_extra_metadata_blocks) { + extra_metadata_block_pointers.emplace_back(block_pointer, 0); + } + metadata_manager->ClearModifiedBlocks(column_pointers); + metadata_manager->ClearModifiedBlocks(extra_metadata_block_pointers); + metadata_manager->ClearModifiedBlocks(deletes_pointers); + + // remember metadata_blocks to avoid loading them on future checkpoints + has_metadata_blocks = true; + extra_metadata_blocks = row_group_pointer.extra_metadata_blocks; + } + // merge row group stats into the global stats + auto lock = global_stats.GetLock(); + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + GetColumn(column_idx).MergeIntoStatistics(global_stats.GetStats(*lock, column_idx).Statistics()); + } return row_group_pointer; } D_ASSERT(write_data.states.size() == columns.size()); @@ -1093,6 +1260,9 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite // Just as above, the state can refer to many other states, so this // can cascade recursively into more pointer writes. auto persistent_data = state->ToPersistentData(); + // increment the "start" in all data pointers by the row group start + // FIXME: this is only necessary when targeting old serialization + IncrementSegmentStart(persistent_data, row_group_start); BinarySerializer serializer(data_writer); serializer.Begin(); persistent_data.Serialize(serializer); @@ -1109,15 +1279,15 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite } // this metadata block is not stored - add it to the extra metadata blocks row_group_pointer.extra_metadata_blocks.push_back(column_pointer.block_pointer); + metadata_blocks.insert(column_pointer.block_pointer); + } + if (metadata_manager) { + row_group_pointer.deletes_pointers = CheckpointDeletes(*metadata_manager); } // set up the pointers correctly within this row group for future operations column_pointers = row_group_pointer.data_pointers; has_metadata_blocks = true; extra_metadata_blocks = row_group_pointer.extra_metadata_blocks; - - if (metadata_manager) { - row_group_pointer.deletes_pointers = CheckpointDeletes(*metadata_manager); - } Verify(); return row_group_pointer; } @@ -1126,7 +1296,8 @@ bool RowGroup::HasChanges() const { if (has_changes) { return true; } - if (version_info.load()) { + auto version_info_loaded = version_info.load(); + if (version_info_loaded && version_info_loaded->HasUnserializedChanges()) { // we have deletes return true; } @@ -1153,13 +1324,13 @@ bool RowGroup::IsPersistent() const { return true; } -PersistentRowGroupData RowGroup::SerializeRowGroupInfo() const { +PersistentRowGroupData RowGroup::SerializeRowGroupInfo(idx_t row_group_start) const { // all columns are persistent - serialize PersistentRowGroupData result; for (auto &col : columns) { result.column_data.push_back(col->Serialize()); } - result.start = start; + result.start = row_group_start; result.count = count; return result; } @@ -1204,26 +1375,57 @@ RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { //===--------------------------------------------------------------------===// // GetPartitionStats //===--------------------------------------------------------------------===// -PartitionStatistics RowGroup::GetPartitionStats() const { +struct DuckDBPartitionRowGroup : public PartitionRowGroup { + explicit DuckDBPartitionRowGroup(const RowGroup &row_group_p, bool is_exact_p) + : row_group(row_group_p), is_exact(is_exact_p) { + } + + const RowGroup &row_group; + const bool is_exact; + + unique_ptr 
GetColumnStatistics(column_t column_id) override { + return row_group.GetStatistics(column_id); + } + + bool MinMaxIsExact(const BaseStatistics &stats) override { + if (!is_exact || row_group.HasChanges()) { + return false; + } + if (stats.GetStatsType() == StatisticsType::STRING_STATS) { + if (!StringStats::HasMaxStringLength(stats)) { + return false; + } + const idx_t max_length = StringStats::MaxStringLength(stats); + return max_length == StringStats::Max(stats).length() && max_length == StringStats::Min(stats).length(); + } + return stats.GetStatsType() == StatisticsType::NUMERIC_STATS; + } +}; + +PartitionStatistics RowGroup::GetPartitionStats(idx_t row_group_start) { PartitionStatistics result; - result.row_start = start; + result.row_start = row_group_start; result.count = count; if (HasUnloadedDeletes() || version_info.load().get()) { // we have version info - approx count result.count_type = CountType::COUNT_APPROXIMATE; + result.partition_row_group = make_shared_ptr(*this, false); } else { result.count_type = CountType::COUNT_EXACT; + result.partition_row_group = make_shared_ptr(*this, true); } + return result; } //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -void RowGroup::GetColumnSegmentInfo(idx_t row_group_index, vector &result) { +void RowGroup::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, + vector &result) { for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { auto &col_data = GetColumn(col_idx); - col_data.GetColumnSegmentInfo(row_group_index, {col_idx}, result); + col_data.GetColumnSegmentInfo(context, row_group_index, {col_idx}, result); } } @@ -1252,14 +1454,14 @@ class VersionDeleteState { void Flush(); }; -idx_t RowGroup::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count) { - VersionDeleteState del_state(*this, transaction, table, this->start); +idx_t RowGroup::Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count, idx_t row_group_start) { + VersionDeleteState del_state(*this, transaction, table, row_group_start); // obtain a write lock for (idx_t i = 0; i < count; i++) { D_ASSERT(ids[i] >= 0); - D_ASSERT(idx_t(ids[i]) >= this->start && idx_t(ids[i]) < this->start + this->count); - del_state.Delete(ids[i] - UnsafeNumericCast(this->start)); + D_ASSERT(idx_t(ids[i]) >= row_group_start && idx_t(ids[i]) < row_group_start + this->count); + del_state.Delete(ids[i] - UnsafeNumericCast(row_group_start)); } del_state.Flush(); return del_state.delete_count; diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index d44ca0544..699f50c94 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -14,17 +14,20 @@ #include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/persistent_table_data.hpp" #include "duckdb/storage/table/row_group_segment_tree.hpp" +#include "duckdb/storage/table/row_version_manager.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/execution/index/art/art.hpp" namespace duckdb { //===--------------------------------------------------------------------===// // Row Group Segment Tree 
//===--------------------------------------------------------------------===// -RowGroupSegmentTree::RowGroupSegmentTree(RowGroupCollection &collection) - : SegmentTree(), collection(collection), current_row_group(0), max_row_group(0) { +RowGroupSegmentTree::RowGroupSegmentTree(RowGroupCollection &collection, idx_t base_row_id) + : SegmentTree(base_row_id), collection(collection), current_row_group(0), max_row_group(0) { } RowGroupSegmentTree::~RowGroupSegmentTree() { } @@ -38,7 +41,7 @@ void RowGroupSegmentTree::Initialize(PersistentTableData &data) { root_pointer = data.block_pointer; } -unique_ptr RowGroupSegmentTree::LoadSegment() { +shared_ptr RowGroupSegmentTree::LoadSegment() const { if (current_row_group >= max_row_group) { reader.reset(); finished_loading = true; @@ -49,7 +52,7 @@ unique_ptr RowGroupSegmentTree::LoadSegment() { auto row_group_pointer = RowGroup::Deserialize(deserializer); deserializer.End(); current_row_group++; - return make_uniq(collection, std::move(row_group_pointer)); + return make_shared_ptr(collection, std::move(row_group_pointer)); } //===--------------------------------------------------------------------===// @@ -62,11 +65,11 @@ RowGroupCollection::RowGroupCollection(shared_ptr info_p, TableIO } RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockManager &block_manager, - vector types_p, idx_t row_start_p, idx_t total_rows_p, + vector types_p, idx_t row_start, idx_t total_rows_p, idx_t row_group_size_p) : block_manager(block_manager), row_group_size(row_group_size_p), total_rows(total_rows_p), info(std::move(info_p)), - types(std::move(types_p)), row_start(row_start_p), allocation_size(0), requires_new_row_group(false) { - row_groups = make_shared_ptr(*this); + types(std::move(types_p)), owned_row_groups(make_shared_ptr(*this, row_start)), + allocation_size(0), requires_new_row_group(false) { } idx_t RowGroupCollection::GetTotalRows() const { @@ -89,14 +92,24 @@ MetadataManager &RowGroupCollection::GetMetadataManager() { return GetBlockManager().GetMetadataManager(); } +shared_ptr RowGroupCollection::GetRowGroups() const { + lock_guard guard(row_group_pointer_lock); + return owned_row_groups; +} + +void RowGroupCollection::SetRowGroups(shared_ptr new_row_groups) { + lock_guard guard(row_group_pointer_lock); + owned_row_groups = std::move(new_row_groups); +} + //===--------------------------------------------------------------------===// // Initialize //===--------------------------------------------------------------------===// void RowGroupCollection::Initialize(PersistentTableData &data) { - D_ASSERT(this->row_start == 0); - auto l = row_groups->Lock(); + D_ASSERT(owned_row_groups->GetBaseRowId() == 0); + auto l = owned_row_groups->Lock(); this->total_rows = data.total_rows; - row_groups->Initialize(data); + owned_row_groups->Initialize(data); stats.Initialize(types, data); metadata_pointer = data.base_table_pointer; } @@ -107,12 +120,12 @@ void RowGroupCollection::FinalizeCheckpoint(MetaBlockPointer pointer) { void RowGroupCollection::Initialize(PersistentCollectionData &data) { stats.InitializeEmpty(types); - auto l = row_groups->Lock(); + auto l = owned_row_groups->Lock(); for (auto &row_group_data : data.row_group_data) { auto row_group = make_uniq(*this, row_group_data); row_group->MergeIntoStatistics(stats); total_rows += row_group->count; - row_groups->AppendSegment(l, std::move(row_group)); + owned_row_groups->AppendSegment(l, std::move(row_group), row_group_data.start); } } @@ -124,26 +137,49 @@ void 
RowGroupCollection::InitializeEmpty() { stats.InitializeEmpty(types); } +ColumnDataType GetColumnDataType(idx_t row_start) { + if (row_start == UnsafeNumericCast(MAX_ROW_ID)) { + return ColumnDataType::INITIAL_TRANSACTION_LOCAL; + } + if (row_start > UnsafeNumericCast(MAX_ROW_ID)) { + return ColumnDataType::TRANSACTION_LOCAL; + } + return ColumnDataType::MAIN_TABLE; +} + void RowGroupCollection::AppendRowGroup(SegmentLock &l, idx_t start_row) { - D_ASSERT(start_row >= row_start); - auto new_row_group = make_uniq(*this, start_row, 0U); - new_row_group->InitializeEmpty(types); - row_groups->AppendSegment(l, std::move(new_row_group)); + auto new_row_group = make_uniq(*this, 0U); + new_row_group->InitializeEmpty(types, GetColumnDataType(start_row)); + owned_row_groups->AppendSegment(l, std::move(new_row_group), start_row); requires_new_row_group = false; } optional_ptr RowGroupCollection::GetRowGroup(int64_t index) { - return row_groups->GetSegmentByIndex(index); + auto result = owned_row_groups->GetSegmentByIndex(index); + if (!result) { + return nullptr; + } + return result->GetNode(); +} + +void RowGroupCollection::SetRowGroup(int64_t index, shared_ptr new_row_group) { + auto result = owned_row_groups->GetSegmentByIndex(index); + if (!result) { + throw InternalException("RowGroupCollection::SetRowGroup - Segment is out of range"); + } + result->SetNode(std::move(new_row_group)); } void RowGroupCollection::Verify() { #ifdef DEBUG idx_t current_total_rows = 0; + auto row_groups = GetRowGroups(); row_groups->Verify(); - for (auto &row_group : row_groups->Segments()) { + for (auto &entry : row_groups->SegmentNodes()) { + auto &row_group = entry.GetNode(); row_group.Verify(); D_ASSERT(&row_group.GetCollection() == this); - D_ASSERT(row_group.start == this->row_start + current_total_rows); + D_ASSERT(entry.GetRowStart() == row_groups->GetBaseRowId() + current_total_rows); current_total_rows += row_group.count; } D_ASSERT(current_total_rows == total_rows.load()); @@ -153,87 +189,97 @@ void RowGroupCollection::Verify() { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -void RowGroupCollection::InitializeScan(CollectionScanState &state, const vector &column_ids, +void RowGroupCollection::InitializeScan(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, optional_ptr table_filters) { - auto row_group = row_groups->GetRootSegment(); + state.row_groups = GetRowGroups(); + auto row_group = state.GetRootSegment(); D_ASSERT(row_group); - state.row_groups = row_groups.get(); - state.max_row = row_start + total_rows; - state.Initialize(GetTypes()); - while (row_group && !row_group->InitializeScan(state)) { - row_group = row_groups->GetNextSegment(row_group); + state.max_row = state.row_groups->GetBaseRowId() + total_rows; + state.Initialize(context, GetTypes()); + while (row_group && !row_group->GetNode().InitializeScan(state, *row_group)) { + row_group = state.GetNextRowGroup(*row_group); } } void RowGroupCollection::InitializeCreateIndexScan(CreateIndexScanState &state) { - state.segment_lock = row_groups->Lock(); + state.row_groups = GetRowGroups(); + state.segment_lock = state.row_groups->Lock(); } -void RowGroupCollection::InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, - idx_t start_row, idx_t end_row) { - auto row_group = row_groups->GetSegment(start_row); +void RowGroupCollection::InitializeScanWithOffset(const 
QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, + idx_t end_row) { + state.row_groups = GetRowGroups(); + auto row_group = state.row_groups->GetSegment(start_row); D_ASSERT(row_group); - state.row_groups = row_groups.get(); state.max_row = end_row; - state.Initialize(GetTypes()); - idx_t start_vector = (start_row - row_group->start) / STANDARD_VECTOR_SIZE; - if (!row_group->InitializeScanWithOffset(state, start_vector)) { + state.Initialize(context, GetTypes()); + idx_t start_vector = (start_row - row_group->GetRowStart()) / STANDARD_VECTOR_SIZE; + if (!row_group->GetNode().InitializeScanWithOffset(state, *row_group, start_vector)) { throw InternalException("Failed to initialize row group scan with offset"); } } -bool RowGroupCollection::InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row) { +bool RowGroupCollection::InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, SegmentNode &row_group, + idx_t vector_index, idx_t max_row) { state.max_row = max_row; - state.row_groups = collection.row_groups.get(); - if (!state.column_scans) { + state.row_groups = collection.GetRowGroups(); + if (state.column_scans.empty()) { // initialize the scan state - state.Initialize(collection.GetTypes()); + state.Initialize(context, collection.GetTypes()); } - return row_group.InitializeScanWithOffset(state, vector_index); + return row_group.GetNode().InitializeScanWithOffset(state, row_group, vector_index); } void RowGroupCollection::InitializeParallelScan(ParallelCollectionScanState &state) { state.collection = this; - state.current_row_group = row_groups->GetRootSegment(); + state.row_groups = GetRowGroups(); + state.current_row_group = state.GetRootSegment(*state.row_groups); state.vector_index = 0; - state.max_row = row_start + total_rows; + state.max_row = state.row_groups->GetBaseRowId() + total_rows; state.batch_index = 0; state.processed_rows = 0; } bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, CollectionScanState &scan_state) { + AssignSharedPointer(scan_state.row_groups, state.row_groups); while (true) { idx_t vector_index; idx_t max_row; - RowGroupCollection *collection; - RowGroup *row_group; + optional_ptr collection; + optional_ptr> row_group; { // select the next row group to scan from the parallel state lock_guard l(state.lock); - if (!state.current_row_group || state.current_row_group->count == 0) { + if (!state.current_row_group) { // no more data left to scan break; } + auto ¤t_row_group = state.current_row_group->GetNode(); + if (current_row_group.count == 0) { + break; + } + auto row_start = state.current_row_group->GetRowStart(); collection = state.collection; row_group = state.current_row_group; if (ClientConfig::GetConfig(context).verify_parallelism) { vector_index = state.vector_index; - max_row = state.current_row_group->start + - MinValue(state.current_row_group->count, - STANDARD_VECTOR_SIZE * state.vector_index + STANDARD_VECTOR_SIZE); - D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < state.current_row_group->count); + max_row = row_start + MinValue(current_row_group.count, + STANDARD_VECTOR_SIZE * state.vector_index + STANDARD_VECTOR_SIZE); + D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < current_row_group.count); state.vector_index++; - if (state.vector_index * STANDARD_VECTOR_SIZE >= state.current_row_group->count) { - 
state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + if (state.vector_index * STANDARD_VECTOR_SIZE >= current_row_group.count) { + state.current_row_group = state.GetNextRowGroup(*state.row_groups, *row_group).get(); state.vector_index = 0; } } else { - state.processed_rows += state.current_row_group->count; + state.processed_rows += current_row_group.count; vector_index = 0; - max_row = state.current_row_group->start + state.current_row_group->count; - state.current_row_group = row_groups->GetNextSegment(state.current_row_group); + max_row = row_start + current_row_group.count; + state.current_row_group = state.GetNextRowGroup(*state.row_groups, *row_group).get(); } max_row = MinValue(max_row, state.max_row); scan_state.batch_index = ++state.batch_index; @@ -242,7 +288,8 @@ bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollec D_ASSERT(row_group); // initialize the scan for this row group - bool need_to_scan = InitializeScanInRowGroup(scan_state, *collection, *row_group, vector_index, max_row); + bool need_to_scan = + InitializeScanInRowGroup(context, scan_state, *collection, *row_group, vector_index, max_row); if (!need_to_scan) { // skip this row group continue; @@ -266,7 +313,7 @@ bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector(row_identifiers); idx_t count = 0; + auto row_groups = GetRowGroups(); for (idx_t i = 0; i < fetch_count; i++) { auto row_id = row_ids[i]; - RowGroup *row_group; + optional_ptr> row_group; { idx_t segment_index; auto l = row_groups->Lock(); @@ -309,17 +357,22 @@ void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, c } row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); } - if (!row_group->Fetch(transaction, UnsafeNumericCast(row_id) - row_group->start)) { + auto ¤t_row_group = row_group->GetNode(); + auto offset_in_row_group = UnsafeNumericCast(row_id) - row_group->GetRowStart(); + if (!current_row_group.Fetch(transaction, offset_in_row_group)) { continue; } - row_group->FetchRow(transaction, state, column_ids, row_id, result, count); + state.row_group = row_group; + current_row_group.FetchRow(transaction, state, column_ids, UnsafeNumericCast(offset_in_row_group), + result, count); count++; } result.SetCardinality(count); } bool RowGroupCollection::CanFetch(TransactionData transaction, const row_t row_id) { - RowGroup *row_group; + auto row_groups = GetRowGroups(); + optional_ptr> row_group; { idx_t segment_index; auto l = row_groups->Lock(); @@ -328,7 +381,9 @@ bool RowGroupCollection::CanFetch(TransactionData transaction, const row_t row_i } row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); } - return row_group->Fetch(transaction, UnsafeNumericCast(row_id) - row_group->start); + auto ¤t_row_group = row_group->GetNode(); + auto offset_in_row_group = UnsafeNumericCast(row_id) - row_group->GetRowStart(); + return current_row_group.Fetch(transaction, offset_in_row_group); } //===--------------------------------------------------------------------===// @@ -343,11 +398,8 @@ TableAppendState::~TableAppendState() { } bool RowGroupCollection::IsEmpty() const { + auto row_groups = GetRowGroups(); auto l = row_groups->Lock(); - return IsEmpty(l); -} - -bool RowGroupCollection::IsEmpty(SegmentLock &l) const { return row_groups->IsEmpty(l); } @@ -357,15 +409,18 @@ void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppe state.total_append_count = 0; // start writing to the row_groups - 
auto l = row_groups->Lock(); - if (IsEmpty(l) || requires_new_row_group) { + state.row_groups = GetRowGroups(); + auto l = state.row_groups->Lock(); + if (state.row_groups->IsEmpty(l) || requires_new_row_group) { // empty row group collection: empty first row group - AppendRowGroup(l, row_start + total_rows); + AppendRowGroup(l, state.row_groups->GetBaseRowId() + total_rows); } - state.start_row_group = row_groups->GetLastSegment(l); - D_ASSERT(this->row_start + total_rows == state.start_row_group->start + state.start_row_group->count); - state.start_row_group->InitializeAppend(state.row_group_append_state); + state.start_row_group = state.row_groups->GetLastSegment(l); + D_ASSERT(state.row_groups->GetBaseRowId() + total_rows == + state.start_row_group->GetRowStart() + state.start_row_group->GetNode().count); + state.start_row_group->GetNode().InitializeAppend(state.row_group_append_state); state.transaction = transaction; + state.row_group_start = state.start_row_group->GetRowStart(); // initialize thread-local stats so we have less lock contention when updating distinct statistics state.stats = TableStatistics(); @@ -399,27 +454,26 @@ bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { current_row_group->MergeIntoStatistics(stats); } remaining -= append_count; - if (remaining > 0) { - // we expect max 1 iteration of this loop (i.e. a single chunk should never overflow more than one - // row_group) - D_ASSERT(chunk.size() == remaining + append_count); - // slice the input chunk - if (remaining < chunk.size()) { - chunk.Slice(append_count, remaining); - } - // append a new row_group - new_row_group = true; - auto next_start = current_row_group->start + state.row_group_append_state.offset_in_row_group; - - auto l = row_groups->Lock(); - AppendRowGroup(l, next_start); - // set up the append state for this row_group - auto last_row_group = row_groups->GetLastSegment(l); - last_row_group->InitializeAppend(state.row_group_append_state); - continue; - } else { + if (remaining == 0) { break; } + // we expect max 1 iteration of this loop (i.e. 
a single chunk should never overflow more than one + // row_group) + D_ASSERT(chunk.size() == remaining + append_count); + // slice the input chunk + if (remaining < chunk.size()) { + chunk.Slice(append_count, remaining); + } + // append a new row_group + new_row_group = true; + auto next_start = state.row_group_start + state.row_group_append_state.offset_in_row_group; + + auto l = state.row_groups->Lock(); + AppendRowGroup(l, next_start); + // set up the append state for this row_group + auto last_row_group = state.row_groups->GetLastSegment(l); + last_row_group->GetNode().InitializeAppend(state.row_group_append_state); + state.row_group_start = next_start; } state.current_row += row_t(total_append_count); @@ -439,10 +493,11 @@ void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppend auto remaining = state.total_append_count; auto row_group = state.start_row_group; while (remaining > 0) { - auto append_count = MinValue(remaining, row_group_size - row_group->count); - row_group->AppendVersionInfo(transaction, append_count); + auto ¤t_row_group = row_group->GetNode(); + auto append_count = MinValue(remaining, row_group_size - current_row_group.count); + current_row_group.AppendVersionInfo(transaction, append_count); remaining -= append_count; - row_group = row_groups->GetNextSegment(row_group); + row_group = state.row_groups->GetNextSegment(*row_group); } total_rows += state.total_append_count; @@ -467,27 +522,30 @@ void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppend } void RowGroupCollection::CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count) { + auto row_groups = GetRowGroups(); auto row_group = row_groups->GetSegment(row_start); D_ASSERT(row_group); idx_t current_row = row_start; idx_t remaining = count; while (true) { - idx_t start_in_row_group = current_row - row_group->start; - idx_t append_count = MinValue(row_group->count - start_in_row_group, remaining); + auto ¤t_row_group = row_group->GetNode(); + idx_t start_in_row_group = current_row - row_group->GetRowStart(); + idx_t append_count = MinValue(current_row_group.count - start_in_row_group, remaining); - row_group->CommitAppend(commit_id, start_in_row_group, append_count); + current_row_group.CommitAppend(commit_id, start_in_row_group, append_count); current_row += append_count; remaining -= append_count; if (remaining == 0) { break; } - row_group = row_groups->GetNextSegment(row_group); + row_group = row_groups->GetNextSegment(*row_group); } } void RowGroupCollection::RevertAppendInternal(idx_t start_row) { total_rows = start_row; + auto row_groups = GetRowGroups(); auto l = row_groups->Lock(); idx_t segment_count = row_groups->GetSegmentCount(l); @@ -502,40 +560,49 @@ void RowGroupCollection::RevertAppendInternal(idx_t start_row) { segment_index = segment_count - 1; } auto &segment = *row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); - if (segment.start == start_row) { + if (segment.GetRowStart() == start_row) { // we are truncating exactly this row group - erase it entirely row_groups->EraseSegments(l, segment_index); + + if (segment_index > 0) { + // if we have a previous segment, we need to update the next pointer + auto previous_segment = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index - 1)); + previous_segment->SetNext(nullptr); + } } else { // we need to truncate within a row group // remove any segments AFTER this segment: they should be deleted entirely row_groups->EraseSegments(l, segment_index + 1); - segment.next = 
nullptr; - segment.RevertAppend(start_row); + segment.SetNext(nullptr); + segment.GetNode().RevertAppend(start_row - segment.GetRowStart()); } } void RowGroupCollection::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_t count) { + auto row_groups = GetRowGroups(); auto row_group = row_groups->GetSegment(start); D_ASSERT(row_group); idx_t current_row = start; idx_t remaining = count; while (true) { - idx_t start_in_row_group = current_row - row_group->start; - idx_t append_count = MinValue(row_group->count - start_in_row_group, remaining); + auto ¤t_row_group = row_group->GetNode(); + idx_t start_in_row_group = current_row - row_group->GetRowStart(); + idx_t append_count = MinValue(current_row_group.count - start_in_row_group, remaining); - row_group->CleanupAppend(lowest_transaction, start_in_row_group, append_count); + current_row_group.CleanupAppend(lowest_transaction, start_in_row_group, append_count); current_row += append_count; remaining -= append_count; if (remaining == 0) { break; } - row_group = row_groups->GetNextSegment(row_group); + row_group = row_groups->GetNextSegment(*row_group); } } bool RowGroupCollection::IsPersistent() const { + auto row_groups = GetRowGroups(); for (auto &row_group : row_groups->Segments()) { if (!row_group.IsPersistent()) { return false; @@ -547,9 +614,10 @@ bool RowGroupCollection::IsPersistent() const { void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptr table, optional_ptr commit_state) { D_ASSERT(data.types == types); - auto start_index = row_start + total_rows.load(); + auto segments = data.GetRowGroups()->MoveSegments(); + auto row_groups = GetRowGroups(); + auto start_index = row_groups->GetBaseRowId() + total_rows.load(); auto index = start_index; - auto segments = data.row_groups->MoveSegments(); // check if the row groups we are merging are optimistically written // if all row groups are optimistically written we keep around the block pointers @@ -557,7 +625,7 @@ void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptrGetNode(); if (!row_group.IsPersistent()) { break; } @@ -568,12 +636,12 @@ void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptrMoveToCollection(*this, index); + auto row_group = entry->MoveNode(); + row_group->MoveToCollection(*this); if (commit_state && (index - start_index) < optimistically_written_count) { // serialize the block pointers of this row group - auto persistent_data = row_group->SerializeRowGroupInfo(); + auto persistent_data = row_group->SerializeRowGroupInfo(index); persistent_data.types = types; row_group_data->row_group_data.push_back(std::move(persistent_data)); } @@ -598,22 +666,27 @@ idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, // usually all (or many) ids belong to the same row group // we iterate over the ids and check for every id if it belongs to the same row group as their predecessor idx_t pos = 0; + auto row_groups = GetRowGroups(); do { idx_t start = pos; auto row_group = row_groups->GetSegment(UnsafeNumericCast(ids[start])); + + auto ¤t_row_group = row_group->GetNode(); + auto row_start = row_group->GetRowStart(); + auto row_end = row_start + current_row_group.count; for (pos++; pos < count; pos++) { D_ASSERT(ids[pos] >= 0); // check if this id still belongs to this row group - if (idx_t(ids[pos]) < row_group->start) { + if (idx_t(ids[pos]) < row_start) { // id is before row_group start -> it does not break; } - if (idx_t(ids[pos]) >= row_group->start + row_group->count) { + if 
(idx_t(ids[pos]) >= row_end) { + // id is after row group end -> it does not break; } } - delete_count += row_group->Delete(transaction, table, ids + start, pos - start); + delete_count += current_row_group.Delete(transaction, table, ids + start, pos - start, row_start); } while (pos < count); return delete_count; @@ -622,14 +695,16 @@ idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table, //===--------------------------------------------------------------------===// // Update //===--------------------------------------------------------------------===// -optional_ptr RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t &pos, idx_t count) const { - auto row_group = row_groups->GetSegment(UnsafeNumericCast(ids[pos])); - - row_t base_id = - UnsafeNumericCast(row_group->start + ((UnsafeNumericCast(ids[pos]) - row_group->start) / - STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE)); +optional_ptr> RowGroupCollection::NextUpdateRowGroup(RowGroupSegmentTree &row_groups, row_t *ids, + idx_t &pos, idx_t count) const { + auto row_group = row_groups.GetSegment(UnsafeNumericCast(ids[pos])); + + auto &current_row_group = row_group->GetNode(); + auto row_start = row_group->GetRowStart(); + row_t base_id = UnsafeNumericCast( + row_start + ((UnsafeNumericCast(ids[pos]) - row_start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE)); auto max_id = - MinValue(base_id + STANDARD_VECTOR_SIZE, UnsafeNumericCast(row_group->start + row_group->count)); + MinValue(base_id + STANDARD_VECTOR_SIZE, UnsafeNumericCast(row_start + current_row_group.count)); for (pos++; pos < count; pos++) { D_ASSERT(ids[pos] >= 0); // check if this id still belongs to this vector in this row group @@ -645,34 +720,88 @@ optional_ptr RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t return row_group; } -void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, - DataChunk &updates) { +void RowGroupCollection::Update(TransactionData transaction, DataTable &data_table, row_t *ids, + const vector &column_ids, DataChunk &updates) { D_ASSERT(updates.size() >= 1); idx_t pos = 0; + auto row_groups = GetRowGroups(); do { idx_t start = pos; - auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->Update(transaction, updates, ids, start, pos - start, column_ids); + auto row_group = NextUpdateRowGroup(*row_groups, ids, pos, updates.size()); + + auto &current_row_group = row_group->GetNode(); + current_row_group.Update(transaction, data_table, updates, ids, start, pos - start, column_ids, + row_group->GetRowStart()); auto l = stats.GetLock(); for (idx_t i = 0; i < column_ids.size(); i++) { auto column_id = column_ids[i]; - stats.MergeStats(*l, column_id.index, *row_group->GetStatistics(column_id.index)); + stats.MergeStats(*l, column_id.index, *current_row_group.GetStatistics(column_id.index)); } } while (pos < updates.size()); } -void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count) { +void GetIndexRemovalTargets(IndexEntry &entry, IndexRemovalType removal_type, optional_ptr &append_target, + optional_ptr &remove_target) { + auto &main_index = entry.index->Cast(); + + // not all indexes require delta indexes - this is tracked through BoundIndex::RequiresTransactionality + // if an index does not require this we skip creating and appending to "deleted_rows_in_use" + + switch (removal_type) { + case IndexRemovalType::MAIN_INDEX_ONLY: + // directly remove
from main index without appending to delta indexes + remove_target = main_index; + break; + case IndexRemovalType::REVERT_MAIN_INDEX_ONLY: + // revert main index only append - just add back to index + append_target = main_index; + break; + case IndexRemovalType::MAIN_INDEX: + // regular removal from main index - add rows to delta index if required + if (index_requires_delta) { + if (!entry.deleted_rows_in_use) { + // create "deleted_rows_in_use" if it does not exist yet + entry.deleted_rows_in_use = + main_index.CreateEmptyCopy("deleted_rows_in_use_", IndexConstraintType::NONE); + } + append_target = entry.deleted_rows_in_use; + } + remove_target = main_index; + break; + case IndexRemovalType::REVERT_MAIN_INDEX: + // revert regular append to main index - remove from deleted_rows_in_use if we appended there before + append_target = main_index; + if (index_requires_delta) { + remove_target = entry.deleted_rows_in_use; + } + break; + case IndexRemovalType::DELETED_ROWS_IN_USE: + // remove from removal index if we appended any rows + if (index_requires_delta) { + remove_target = entry.deleted_rows_in_use; + } + break; + default: + throw InternalException("Unsupported IndexRemovalType"); + } +} + +void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, + Vector &row_identifiers, idx_t count, IndexRemovalType removal_type) { auto row_ids = FlatVector::GetData(row_identifiers); - // Collect all indexed columns. + // Collect all Indexed columns on the table. unordered_set indexed_column_id_set; indexes.Scan([&](Index &index) { - D_ASSERT(index.IsBound()); auto &set = index.GetColumnIdSet(); indexed_column_id_set.insert(set.begin(), set.end()); return false; }); + + // If we are in WAL replay, delete data will be buffered, and so we sort the column_ids + // since the sorted form will be the mapping used to get back physical IDs from the buffered index chunk. vector column_ids; for (auto &col : indexed_column_id_set) { column_ids.emplace_back(col); @@ -683,13 +812,14 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ for (auto &col : column_ids) { column_types.push_back(types[col.GetPrimaryIndex()]); } + auto row_groups = GetRowGroups(); // Initialize the fetch state. Only use indexed columns. TableScanState state; - state.Initialize(std::move(column_ids)); - state.table_state.max_row = row_start + total_rows; + auto column_ids_copy = column_ids; + state.Initialize(std::move(column_ids_copy)); + state.table_state.max_row = row_groups->GetBaseRowId() + total_rows; - // Used for scanning data. Only contains the indexed columns. DataChunk fetch_chunk; fetch_chunk.Initialize(GetAllocator(), column_types); @@ -713,13 +843,16 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ // Figure out which row_group to fetch from. auto row_id = row_ids[r]; auto row_group = row_groups->GetSegment(UnsafeNumericCast(row_id)); - auto row_group_vector_idx = (UnsafeNumericCast(row_id) - row_group->start) / STANDARD_VECTOR_SIZE; - auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start; + + auto ¤t_row_group = row_group->GetNode(); + auto row_start = row_group->GetRowStart(); + auto row_group_vector_idx = (UnsafeNumericCast(row_id) - row_start) / STANDARD_VECTOR_SIZE; + auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_start; // Fetch the current vector into fetch_chunk. 
- state.table_state.Initialize(GetTypes()); - row_group->InitializeScanWithOffset(state.table_state, row_group_vector_idx); - row_group->ScanCommitted(state.table_state, fetch_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); + state.table_state.Initialize(context, GetTypes()); + current_row_group.InitializeScanWithOffset(state.table_state, *row_group, row_group_vector_idx); + current_row_group.ScanCommitted(state.table_state, fetch_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); fetch_chunk.Verify(); // Check for any remaining row ids, if they also fall into this vector. @@ -749,34 +882,64 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ result_chunk.SetCardinality(fetch_chunk); // Slice the vector with all rows that are present in this vector. - // Then, erase all values from the indexes. + // If the index is bound, delete the data. If unbound, buffer into unbound_index. result_chunk.Slice(sel, sel_count); - indexes.Scan([&](Index &index) { + indexes.ScanEntries([&](IndexEntry &entry) { + auto &index = *entry.index; if (index.IsBound()) { - index.Cast().Delete(result_chunk, row_identifiers); + lock_guard guard(entry.lock); + // check which indexes we should append to or remove from + // note that this method might also involve appending to indexes + // the reason for that is that we have "delta" indexes that we must fill with data we are removing + // OR because we are actually reverting a previous removal + optional_ptr append_target, remove_target; + GetIndexRemovalTargets(entry, removal_type, append_target, remove_target); + + // perform the targeted append / removal + if (append_target) { + IndexAppendInfo append_info; + auto error = append_target->Append(result_chunk, row_identifiers, append_info); + if (error.HasError()) { + throw InternalException("Failed to append to %s: %s", append_target->name, error.Message()); + } + } + if (remove_target) { + remove_target->Delete(result_chunk, row_identifiers); + } return false; } - throw MissingExtensionException( - "Cannot delete from index '%s', unknown index type '%s'. You need to load the " - "extension that provides this index type before table '%s' can be modified.", - index.GetIndexName(), index.GetIndexType(), info->GetTableName()); + // Buffering takes only the indexed columns in ordering of the column_ids mapping. 
+ DataChunk index_column_chunk; + index_column_chunk.InitializeEmpty(column_types); + for (idx_t i = 0; i < column_types.size(); i++) { + auto col_id = column_ids[i].GetPrimaryIndex(); + index_column_chunk.data[i].Reference(result_chunk.data[col_id]); + } + index_column_chunk.SetCardinality(result_chunk.size()); + auto &unbound_index = index.Cast(); + unbound_index.BufferChunk(index_column_chunk, row_identifiers, column_ids, BufferedIndexReplay::DEL_ENTRY); + return false; }); } } -void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates) { +void RowGroupCollection::UpdateColumn(TransactionData transaction, DataTable &data_table, Vector &row_ids, + const vector &column_path, DataChunk &updates) { D_ASSERT(updates.size() >= 1); auto ids = FlatVector::GetData(row_ids); idx_t pos = 0; + auto row_groups = GetRowGroups(); do { idx_t start = pos; - auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->UpdateColumn(transaction, updates, row_ids, start, pos - start, column_path); + auto row_group = NextUpdateRowGroup(*row_groups, ids, pos, updates.size()); + auto ¤t_row_group = row_group->GetNode(); + current_row_group.UpdateColumn(transaction, data_table, updates, row_ids, start, pos - start, column_path, + row_group->GetRowStart()); auto lock = stats.GetLock(); auto primary_column_idx = column_path[0]; - row_group->MergeIntoStatistics(primary_column_idx, stats.GetStats(*lock, primary_column_idx).Statistics()); + current_row_group.MergeIntoStatistics(primary_column_idx, + stats.GetStats(*lock, primary_column_idx).Statistics()); } while (pos < updates.size()); } @@ -784,22 +947,54 @@ void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_i // Checkpoint State //===--------------------------------------------------------------------===// struct CollectionCheckpointState { - CollectionCheckpointState(RowGroupCollection &collection, TableDataWriter &writer, - vector> &segments, TableStatistics &global_stats) - : collection(collection), writer(writer), executor(writer.CreateTaskExecutor()), segments(segments), - global_stats(global_stats) { - writers.resize(segments.size()); - write_data.resize(segments.size()); + CollectionCheckpointState(RowGroupCollection &collection, TableDataWriter &writer, TableStatistics &global_stats, + RowGroupSegmentTree &row_groups) + : collection(collection), writer(writer), executor(writer.CreateTaskExecutor()), global_stats(global_stats), + row_groups(row_groups) { + auto segment_count = row_groups.GetSegmentCount(); + writers.resize(segment_count); + write_data.resize(segment_count); + dropped_segments = make_uniq_array(segment_count); + overridden_segments.resize(segment_count); } RowGroupCollection &collection; TableDataWriter &writer; unique_ptr executor; - vector> &segments; vector> writers; vector write_data; TableStatistics &global_stats; - mutex write_lock; + RowGroupSegmentTree &row_groups; + + idx_t SegmentCount() const { + return writers.size(); + } + optional_ptr> GetSegment(idx_t index) { + if (overridden_segments[index]) { + return *overridden_segments[index]; + } + if (dropped_segments[index]) { + // segment was dropped + return nullptr; + } + return row_groups.GetSegmentByIndex(NumericCast(index)); + } + + void DropSegment(idx_t index) { + dropped_segments[index] = true; + } + + bool SegmentIsDropped(idx_t index) const { + return !overridden_segments[index] && dropped_segments[index]; + } + + void SetSegment(idx_t row_start, idx_t 
index, shared_ptr row_group) { + overridden_segments[index] = make_uniq>(row_start, std::move(row_group), index); + } + +private: + vector>> overridden_segments; + unique_array dropped_segments; }; class BaseCheckpointTask : public BaseExecutorTask { @@ -819,9 +1014,9 @@ class CheckpointTask : public BaseCheckpointTask { } void ExecuteTask() override { - auto &entry = checkpoint_state.segments[index]; - auto &row_group = *entry.node; - checkpoint_state.writers[index] = checkpoint_state.writer.GetRowGroupWriter(*entry.node); + auto entry = checkpoint_state.GetSegment(index); + auto &row_group = entry->GetNode(); + checkpoint_state.writers[index] = checkpoint_state.writer.GetRowGroupWriter(row_group); checkpoint_state.write_data[index] = row_group.WriteToDisk(*checkpoint_state.writers[index]); } @@ -837,18 +1032,19 @@ class CheckpointTask : public BaseCheckpointTask { // Vacuum //===--------------------------------------------------------------------===// struct VacuumState { - bool can_vacuum_deletes = false; + bool can_vacuum_deletes = true; + bool can_change_row_ids = false; idx_t row_start = 0; idx_t next_vacuum_idx = 0; - vector row_group_counts; + vector row_group_counts; }; class VacuumTask : public BaseCheckpointTask { public: VacuumTask(CollectionCheckpointState &checkpoint_state, VacuumState &vacuum_state, idx_t segment_idx, - idx_t merge_count, idx_t target_count, idx_t merge_rows, idx_t row_start) + idx_t merge_count, idx_t target_count, idx_t merge_rows) : BaseCheckpointTask(checkpoint_state), vacuum_state(vacuum_state), segment_idx(segment_idx), - merge_count(merge_count), target_count(target_count), merge_rows(merge_rows), row_start(row_start) { + merge_count(merge_count), target_count(target_count), merge_rows(merge_rows) { } void ExecuteTask() override { @@ -859,16 +1055,14 @@ class VacuumTask : public BaseCheckpointTask { vector> new_row_groups; vector append_counts; idx_t row_group_rows = merge_rows; - idx_t start = row_start; for (idx_t target_idx = 0; target_idx < target_count; target_idx++) { idx_t current_row_group_rows = MinValue(row_group_rows, row_group_size); - auto new_row_group = make_uniq(collection, start, current_row_group_rows); - new_row_group->InitializeEmpty(types); + auto new_row_group = make_uniq(collection, current_row_group_rows); + new_row_group->InitializeEmpty(types, ColumnDataType::MAIN_TABLE); new_row_groups.push_back(std::move(new_row_group)); append_counts.push_back(0); row_group_rows -= current_row_group_rows; - start += current_row_group_rows; } DataChunk scan_chunk; @@ -887,19 +1081,24 @@ class VacuumTask : public BaseCheckpointTask { TableScanState scan_state; scan_state.Initialize(column_ids); - scan_state.table_state.Initialize(types); + scan_state.table_state.Initialize(QueryContext(), types); scan_state.table_state.max_row = idx_t(-1); idx_t merged_groups = 0; idx_t total_row_groups = vacuum_state.row_group_counts.size(); + optional_idx row_start; for (idx_t c_idx = segment_idx; merged_groups < merge_count && c_idx < total_row_groups; c_idx++) { if (vacuum_state.row_group_counts[c_idx] == 0) { continue; } merged_groups++; - auto ¤t_row_group = *checkpoint_state.segments[c_idx].node; + auto current_segment = checkpoint_state.GetSegment(c_idx); + if (!row_start.IsValid()) { + row_start = current_segment->GetRowStart(); + } + auto ¤t_row_group = current_segment->GetNode(); - current_row_group.InitializeScan(scan_state.table_state); + current_row_group.InitializeScan(scan_state.table_state, *current_segment); while (true) { 
scan_chunk.Reset(); @@ -929,7 +1128,7 @@ class VacuumTask : public BaseCheckpointTask { } // drop the row group after merging current_row_group.CommitDrop(); - checkpoint_state.segments[c_idx].node.reset(); + checkpoint_state.DropSegment(c_idx); } idx_t total_append_count = 0; for (idx_t target_idx = 0; target_idx < target_count; target_idx++) { @@ -937,7 +1136,8 @@ class VacuumTask : public BaseCheckpointTask { row_group->Verify(); // assign the new row group to the current segment - checkpoint_state.segments[segment_idx + target_idx].node = std::move(row_group); + checkpoint_state.SetSegment(row_start.GetIndex() + total_append_count, segment_idx + target_idx, + std::move(row_group)); total_append_count += append_counts[target_idx]; } if (total_append_count != merge_rows) { @@ -960,30 +1160,71 @@ class VacuumTask : public BaseCheckpointTask { idx_t merge_count; idx_t target_count; idx_t merge_rows; - idx_t row_start; }; -void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state, - vector> &segments) { - auto checkpoint_type = checkpoint_state.writer.GetCheckpointType(); - bool vacuum_is_allowed = checkpoint_type != CheckpointType::CONCURRENT_CHECKPOINT; - // currently we can only vacuum deletes if we are doing a full checkpoint and there are no indexes - state.can_vacuum_deletes = info->GetIndexes().Empty() && vacuum_is_allowed; +void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state) { + auto options = checkpoint_state.writer.GetCheckpointOptions(); + // currently we can only vacuum deletes if we are doing a full checkpoint + state.can_vacuum_deletes = options.type != CheckpointType::CONCURRENT_CHECKPOINT; if (!state.can_vacuum_deletes) { return; } + + // if there are indexes - we cannot change row-ids + // this limits what kind of vacuuming we can do + state.can_change_row_ids = info->GetIndexes().Empty(); // obtain the set of committed row counts for each row group - state.row_group_counts.reserve(segments.size()); - for (auto &entry : segments) { - auto &row_group = *entry.node; + vector committed_counts; + state.row_group_counts.reserve(checkpoint_state.SegmentCount()); + for (auto &entry : checkpoint_state.row_groups.SegmentNodes()) { + auto &row_group = entry.GetNode(); + auto should_checkpoint = row_group.ShouldCheckpointRowGroup(options.transaction_id); + if (!should_checkpoint) { + // this row group does not belong to this checkpoint - it was written by a newer commit + // don't vacuum - otherwise we might move this row group around + // which could cause the subsequent commit / clean-up to fail + state.can_vacuum_deletes = false; + return; + } auto row_group_count = row_group.GetCommittedRowCount(); + if (!state.can_change_row_ids) { + idx_t total_count = row_group.count; + committed_counts.emplace_back(row_group_count); + // we cannot change row ids and this row group has deletes + // vacuuming here would alter row ids - so skip it + if (total_count != row_group_count) { + state.row_group_counts.emplace_back(); + continue; + } + } if (row_group_count == 0) { // empty row group - we can drop it entirely row_group.CommitDrop(); - entry.node.reset(); + checkpoint_state.DropSegment(entry.GetIndex()); } state.row_group_counts.push_back(row_group_count); } + if (!state.can_change_row_ids && options.type != CheckpointType::CONCURRENT_CHECKPOINT) { + // if we cannot change row ids we might still be able to vacuum trailing deletions + // since that would not change the row-ids 
of any non-deleted rows + auto segment_count = state.row_group_counts.size(); + for (idx_t i = segment_count; i > 0; i--) { + auto segment_idx = i - 1; + if (!committed_counts[segment_idx].IsValid()) { + // cannot vacuum this row group + break; + } + if (committed_counts[segment_idx].GetIndex() != 0) { + // multiple rows found here - skip + break; + } + auto &entry = *checkpoint_state.row_groups.GetSegmentByIndex(NumericCast(segment_idx)); + auto &row_group = entry.GetNode(); + D_ASSERT(entry.GetIndex() == segment_idx); + row_group.CommitDrop(); + checkpoint_state.DropSegment(segment_idx); + } + } } bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoint_state, VacuumState &state, @@ -998,9 +1239,9 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // this segment is being vacuumed by a previously scheduled task return true; } - if (state.row_group_counts[segment_idx] == 0) { + if (state.row_group_counts[segment_idx].IsValid() && state.row_group_counts[segment_idx].GetIndex() == 0) { // segment was already dropped - skip - D_ASSERT(!checkpoint_state.segments[segment_idx].node); + D_ASSERT(checkpoint_state.SegmentIsDropped(segment_idx)); return false; } if (!schedule_vacuum) { @@ -1022,19 +1263,24 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi auto total_target_size = target_count * row_group_size; merge_count = 0; merge_rows = 0; - for (next_idx = segment_idx; next_idx < checkpoint_state.segments.size(); next_idx++) { - if (state.row_group_counts[next_idx] == 0) { + for (next_idx = segment_idx; next_idx < checkpoint_state.SegmentCount(); next_idx++) { + if (!state.row_group_counts[next_idx].IsValid()) { + // cannot vacuum this row group - break + break; + } + auto next_row_count = state.row_group_counts[next_idx].GetIndex(); + if (next_row_count == 0) { continue; } - if (merge_rows + state.row_group_counts[next_idx] > total_target_size) { + if (merge_rows + next_row_count > total_target_size) { // does not fit break; } // we can merge this row group together with the other row group - merge_rows += state.row_group_counts[next_idx]; + merge_rows += next_row_count; merge_count++; } - if (next_idx == checkpoint_state.segments.size()) { + if (next_idx == checkpoint_state.SegmentCount()) { // in order to prevent poor performance when performing small appends, we only merge row groups at the end // if we can reach a "target" size of twice the current size, or the max row group size // this is to prevent repeated expensive checkpoints where: @@ -1043,7 +1289,7 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // merge it with a row group with 1 row, creating a row group with 100K+2 rows // etc. This leads to constant rewriting of the original 100K rows. 
idx_t minimum_target = - MinValue(state.row_group_counts[segment_idx] * 2, row_group_size) * target_count; + MinValue(state.row_group_counts[segment_idx].GetIndex() * 2, row_group_size) * target_count; if (merge_rows >= STANDARD_VECTOR_SIZE && merge_rows < minimum_target) { // we haven't reached the minimum target - don't do this vacuum next_idx = segment_idx + 1; @@ -1063,8 +1309,8 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // schedule the vacuum task DUCKDB_LOG(checkpoint_state.writer.GetDatabase(), CheckpointLogType, GetAttached(), *info, segment_idx, merge_count, target_count, merge_rows, state.row_start); - auto vacuum_task = make_uniq(checkpoint_state, state, segment_idx, merge_count, target_count, - merge_rows, state.row_start); + auto vacuum_task = + make_uniq(checkpoint_state, state, segment_idx, merge_count, target_count, merge_rows); checkpoint_state.executor->ScheduleTask(std::move(vacuum_task)); // skip vacuuming by the row groups we have merged state.next_vacuum_idx = next_idx; @@ -1081,20 +1327,18 @@ unique_ptr RowGroupCollection::GetCheckpointTask(CollectionCheck } void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &global_stats) { - auto l = row_groups->Lock(); - auto segments = row_groups->MoveSegments(l); + auto row_groups = GetRowGroups(); - CollectionCheckpointState checkpoint_state(*this, writer, segments, global_stats); + CollectionCheckpointState checkpoint_state(*this, writer, global_stats, *row_groups); VacuumState vacuum_state; - InitializeVacuumState(checkpoint_state, vacuum_state, segments); + InitializeVacuumState(checkpoint_state, vacuum_state); try { // schedule tasks idx_t total_vacuum_tasks = 0; auto max_vacuum_tasks = DBConfig::GetSetting(writer.GetDatabase()); - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - auto &entry = segments[segment_idx]; + for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { auto vacuum_tasks = ScheduleVacuumTasks(checkpoint_state, vacuum_state, segment_idx, total_vacuum_tasks < max_vacuum_tasks); if (vacuum_tasks) { @@ -1102,19 +1346,23 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl total_vacuum_tasks++; continue; } - if (!entry.node) { + if (checkpoint_state.SegmentIsDropped(segment_idx)) { // row group was vacuumed/dropped - skip continue; } // schedule a checkpoint task for this row group - entry.node->MoveToCollection(*this, vacuum_state.row_start); - if (writer.GetCheckpointType() != CheckpointType::VACUUM_ONLY) { + auto entry = checkpoint_state.GetSegment(segment_idx); + auto &row_group = entry->GetNode(); + if (!RefersToSameObject(row_group.GetCollection(), *this)) { + throw InternalException("RowGroup Vacuum - row group collection of row group changed"); + } + if (writer.GetCheckpointOptions().type != CheckpointType::VACUUM_ONLY) { DUCKDB_LOG(checkpoint_state.writer.GetDatabase(), CheckpointLogType, GetAttached(), *info, segment_idx, - *entry.node); + row_group, vacuum_state.row_start); auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx); checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task)); } - vacuum_state.row_start += entry.node->count; + vacuum_state.row_start += row_group.count; } } catch (const std::exception &e) { ErrorData error(e); @@ -1129,14 +1377,13 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl // if the table already exists on disk - check if all row groups have 
stayed the same if (DBConfig::GetSetting(writer.GetDatabase()) && metadata_pointer.IsValid()) { bool table_has_changes = false; - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - auto &entry = segments[segment_idx]; - if (!entry.node) { + for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { + if (checkpoint_state.SegmentIsDropped(segment_idx)) { table_has_changes = true; break; } auto &write_state = checkpoint_state.write_data[segment_idx]; - if (write_state.existing_pointers.empty()) { + if (!write_state.reuse_existing_metadata_blocks) { table_has_changes = true; break; } @@ -1146,46 +1393,222 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl // we can directly re-use the metadata pointer // mark all blocks associated with row groups as still being in-use auto &metadata_manager = writer.GetMetadataManager(); - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - auto &entry = segments[segment_idx]; - auto &row_group = *entry.node; + for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { + auto entry = checkpoint_state.GetSegment(segment_idx); + auto &row_group = entry->GetNode(); auto &write_state = checkpoint_state.write_data[segment_idx]; - metadata_manager.ClearModifiedBlocks(write_state.existing_pointers); - metadata_manager.ClearModifiedBlocks(row_group.GetDeletesPointers()); - row_groups->AppendSegment(l, std::move(entry.node)); + metadata_manager.ClearModifiedBlocks(row_group.GetColumnStartPointers()); + D_ASSERT(write_state.reuse_existing_metadata_blocks); + vector extra_metadata_block_pointers; + extra_metadata_block_pointers.reserve(write_state.existing_extra_metadata_blocks.size()); + for (auto &block_pointer : write_state.existing_extra_metadata_blocks) { + extra_metadata_block_pointers.emplace_back(block_pointer, 0); + } + metadata_manager.ClearModifiedBlocks(extra_metadata_block_pointers); + row_group.CheckpointDeletes(metadata_manager); } writer.WriteUnchangedTable(metadata_pointer, total_rows.load()); + + // copy over existing stats into the global stats + CopyStats(global_stats); return; } } + // not all segments have stayed the same - we need to make a new segment tree with the new set of segments + auto new_row_groups = make_shared_ptr(*this, row_groups->GetBaseRowId()); + auto l = new_row_groups->Lock(); + + // initialize new empty stats + global_stats.InitializeEmpty(stats); + idx_t new_total_rows = 0; - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - auto &entry = segments[segment_idx]; - if (!entry.node) { + bool skipped_row_groups = false; + for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { + auto entry = checkpoint_state.GetSegment(segment_idx); + if (!entry) { // row group was vacuumed/dropped - skip + D_ASSERT(checkpoint_state.SegmentIsDropped(segment_idx)); continue; } - auto &row_group = *entry.node; - if (!checkpoint_state.writers[segment_idx]) { + auto &row_group = entry->GetNode(); + auto &row_group_writer = checkpoint_state.writers[segment_idx]; + if (!row_group_writer) { // row group was not checkpointed - this can happen if compressing is disabled for in-memory tables - D_ASSERT(writer.GetCheckpointType() == CheckpointType::VACUUM_ONLY); - row_groups->AppendSegment(l, std::move(entry.node)); + D_ASSERT(writer.GetCheckpointOptions().type == CheckpointType::VACUUM_ONLY); + new_row_groups->AppendSegment(l, 
entry->ReferenceNode()); new_total_rows += row_group.count; + + auto lock = global_stats.GetLock(); + for (idx_t column_idx = 0; column_idx < row_group.GetColumnCount(); column_idx++) { + global_stats.GetStats(*lock, column_idx).Statistics().Merge(*row_group.GetStatistics(column_idx)); + } continue; } - auto row_group_writer = std::move(checkpoint_state.writers[segment_idx]); - if (!row_group_writer) { - throw InternalException("Missing row group writer for index %llu", segment_idx); + auto &row_group_write_data = checkpoint_state.write_data[segment_idx]; + idx_t row_start = new_total_rows; + bool metadata_reuse = row_group_write_data.reuse_existing_metadata_blocks; + auto new_row_group = std::move(row_group_write_data.result_row_group); + if (!new_row_group) { + // row group was unchanged - emit previous row group + new_row_group = entry->ReferenceNode(); } - auto pointer = - row_group.Checkpoint(std::move(checkpoint_state.write_data[segment_idx]), *row_group_writer, global_stats); - writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); - row_groups->AppendSegment(l, std::move(entry.node)); + RowGroupPointer pointer_copy; + auto debug_verify_blocks = DBConfig::GetSetting(GetAttached().GetDatabase()) && + dynamic_cast(&checkpoint_state.writer) != nullptr; + + // check if we should write this row group to the persistent storage + // don't write it if it only has uncommitted transaction-local changes made AFTER this checkpoint was started + if (row_group_write_data.should_checkpoint) { + if (skipped_row_groups) { + throw InternalException("Checkpoint failure - we are writing a row group AFTER we skipped writing a " + "row group due to concurrent insertions. This will change the row-ids of the " + "written row groups which can cause subtle issues later."); + } + auto pointer = + row_group.Checkpoint(std::move(row_group_write_data), *row_group_writer, global_stats, row_start); + + if (debug_verify_blocks) { + pointer_copy = pointer; + } + writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); + } else { + debug_verify_blocks = false; + skipped_row_groups = true; + } + new_row_groups->AppendSegment(l, std::move(new_row_group)); new_total_rows += row_group.count; + + if (debug_verify_blocks) { + if (!pointer_copy.has_metadata_blocks) { + throw InternalException("Checkpointing should always remember metadata blocks"); + } + if (metadata_reuse && pointer_copy.data_pointers != row_group.GetColumnStartPointers()) { + throw InternalException("Column start pointers changed during metadata reuse"); + } + + // Capture blocks that have been written + vector all_written_blocks = pointer_copy.data_pointers; + vector all_metadata_blocks; + for (auto &block : pointer_copy.extra_metadata_blocks) { + all_written_blocks.emplace_back(block, 0); + all_metadata_blocks.emplace_back(block, 0); + } + + // Verify that we can load the metadata correctly again + vector all_quick_read_blocks; + for (auto &ptr : row_group.GetColumnStartPointers()) { + all_quick_read_blocks.emplace_back(ptr); + if (metadata_reuse && !block_manager.GetMetadataManager().BlockHasBeenCleared(ptr)) { + throw InternalException("Found column start block that was not cleared"); + } + } + auto extra_metadata_blocks = row_group.GetOrComputeExtraMetadataBlocks(/* force_compute: */ true); + for (auto &ptr : extra_metadata_blocks) { + auto block_pointer = MetaBlockPointer(ptr, 0); + all_quick_read_blocks.emplace_back(block_pointer); + if (metadata_reuse && 
!block_manager.GetMetadataManager().BlockHasBeenCleared(block_pointer)) { + throw InternalException("Found extra metadata block that was not cleared"); + } + } + + // Deserialize all columns to check if the quick read via GetOrComputeExtraMetadataBlocks was correct + vector all_full_read_blocks; + auto column_start_pointers = row_group.GetColumnStartPointers(); + auto &types = row_group.GetCollection().GetTypes(); + auto &metadata_manager = row_group.GetCollection().GetMetadataManager(); + for (idx_t i = 0; i < column_start_pointers.size(); i++) { + MetadataReader reader(metadata_manager, column_start_pointers[i], &all_full_read_blocks); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, reader, types[i]); + } + + // Derive sets of blocks to compare + set all_written_block_ids; + for (auto &ptr : all_written_blocks) { + all_written_block_ids.insert(ptr.block_pointer); + } + set all_quick_read_block_ids; + for (auto &ptr : all_quick_read_blocks) { + all_quick_read_block_ids.insert(ptr.block_pointer); + } + set all_full_read_block_ids; + for (auto &ptr : all_full_read_blocks) { + all_full_read_block_ids.insert(ptr.block_pointer); + } + if (all_written_block_ids != all_quick_read_block_ids || + all_quick_read_block_ids != all_full_read_block_ids) { + std::stringstream oss; + oss << "Written: "; + for (auto &block : all_written_blocks) { + oss << block << ", "; + } + oss << "\n"; + oss << "Quick read: "; + for (auto &block : all_quick_read_blocks) { + oss << block << ", "; + } + oss << "\n"; + oss << "Full read: "; + for (auto &block : all_full_read_blocks) { + oss << block << ", "; + } + oss << "\n"; + + throw InternalException("Reloading blocks just written does not yield same blocks: " + oss.str()); + } + + vector read_deletes_pointers; + if (!pointer_copy.deletes_pointers.empty()) { + auto root_delete = pointer_copy.deletes_pointers[0]; + auto vm = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager()); + read_deletes_pointers = vm->GetStoragePointers(); + } + + set all_written_deletes_block_ids; + for (auto &ptr : pointer_copy.deletes_pointers) { + all_written_deletes_block_ids.insert(ptr.block_pointer); + } + set all_read_deletes_block_ids; + for (auto &ptr : read_deletes_pointers) { + all_read_deletes_block_ids.insert(ptr.block_pointer); + } + + if (all_written_deletes_block_ids != all_read_deletes_block_ids) { + std::stringstream oss; + oss << "Written: "; + for (auto &block : all_written_deletes_block_ids) { + oss << block << ", "; + } + oss << "\n"; + oss << "Read: "; + for (auto &block : all_read_deletes_block_ids) { + oss << block << ", "; + } + oss << "\n"; + + throw InternalException("Reloading deletes blocks just written does not yield same blocks: " + + oss.str()); + } + } } - total_rows = new_total_rows; l.Release(); + + if (skipped_row_groups) { + // if we skipped any row groups we cannot override the base stats + // because the stats reflect only the *checkpointed* row groups + // the stats of the extra (not checkpointed) row groups are not included + // hence the stats do not correctly reflect the current in-memory state of the table + writer.SetCannotOverrideStats(); + } + + // flush any partial blocks BEFORE updating the row group pointer + // flushing partial blocks updates where data lives + // this cannot be done after other threads start scanning the row groups + // so this HAS to happen before we call "SetRowGroups" to update the row groups + writer.FlushPartialBlocks(); + // override the row group segment tree + total_rows = 
new_total_rows; + SetRowGroups(std::move(new_row_groups)); Verify(); } @@ -1195,7 +1618,7 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl class DestroyTask : public BaseExecutorTask { public: - DestroyTask(TaskExecutor &executor, unique_ptr row_group_p) + DestroyTask(TaskExecutor &executor, shared_ptr row_group_p) : BaseExecutorTask(executor), row_group(std::move(row_group_p)) { } @@ -1204,16 +1627,16 @@ class DestroyTask : public BaseExecutorTask { } private: - unique_ptr row_group; + shared_ptr row_group; }; void RowGroupCollection::Destroy() { - auto l = row_groups->Lock(); - auto &segments = row_groups->ReferenceLoadedSegmentsMutable(l); + auto l = owned_row_groups->Lock(); + auto &segments = owned_row_groups->ReferenceLoadedSegmentsMutable(l); TaskExecutor executor(TaskScheduler::GetScheduler(GetAttached().GetDatabase())); for (auto &segment : segments) { - auto destroy_task = make_uniq(executor, std::move(segment.node)); + auto destroy_task = make_uniq(executor, segment->MoveNode()); executor.ScheduleTask(std::move(destroy_task)); } executor.WorkOnTasks(); @@ -1223,12 +1646,14 @@ void RowGroupCollection::Destroy() { // CommitDrop //===--------------------------------------------------------------------===// void RowGroupCollection::CommitDropColumn(const idx_t column_index) { + auto row_groups = GetRowGroups(); for (auto &row_group : row_groups->Segments()) { row_group.CommitDropColumn(column_index); } } void RowGroupCollection::CommitDropTable() { + auto row_groups = GetRowGroups(); for (auto &row_group : row_groups->Segments()) { row_group.CommitDrop(); } @@ -1239,8 +1664,10 @@ void RowGroupCollection::CommitDropTable() { //===--------------------------------------------------------------------===// vector RowGroupCollection::GetPartitionStats() const { vector result; - for (auto &row_group : row_groups->Segments()) { - result.push_back(row_group.GetPartitionStats()); + auto row_groups = GetRowGroups(); + for (auto &entry : row_groups->SegmentNodes()) { + auto &row_group = entry.GetNode(); + result.push_back(row_group.GetPartitionStats(entry.GetRowStart())); } return result; } @@ -1248,11 +1675,13 @@ vector RowGroupCollection::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -vector RowGroupCollection::GetColumnSegmentInfo() { +vector RowGroupCollection::GetColumnSegmentInfo(const QueryContext &context) { vector result; + auto row_groups = GetRowGroups(); auto lock = row_groups->Lock(); - for (auto &row_group : row_groups->Segments(lock)) { - row_group.GetColumnSegmentInfo(row_group.index, result); + for (auto &node : row_groups->SegmentNodes(lock)) { + auto &row_group = node.GetNode(); + row_group.GetColumnSegmentInfo(context, node.GetIndex(), result); } return result; } @@ -1264,9 +1693,10 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont ExpressionExecutor &default_executor) { idx_t new_column_idx = types.size(); auto new_types = types; + auto row_groups = GetRowGroups(); new_types.push_back(new_column.GetType()); - auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_start, - total_rows.load(), row_group_size); + auto result = make_shared_ptr(info, block_manager, std::move(new_types), + row_groups->GetBaseRowId(), total_rows.load(), row_group_size); DataChunk dummy_chunk; Vector default_vector(new_column.GetType()); @@ -1277,12 +1707,13 @@ 
shared_ptr RowGroupCollection::AddColumn(ClientContext &cont // fill the column with its DEFAULT value, or NULL if none is specified auto new_stats = make_uniq(new_column.GetType()); + auto result_row_groups = result->GetRowGroups(); for (auto &current_row_group : row_groups->Segments()) { auto new_row_group = current_row_group.AddColumn(*result, new_column, default_executor, default_vector); // merge in the statistics new_row_group->MergeIntoStatistics(new_column_idx, new_column_stats.Statistics()); - result->row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group)); } return result; @@ -1291,18 +1722,20 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { D_ASSERT(col_idx < types.size()); auto new_types = types; + auto row_groups = GetRowGroups(); new_types.erase_at(col_idx); - auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_start, - total_rows.load(), row_group_size); + auto result = make_shared_ptr(info, block_manager, std::move(new_types), + row_groups->GetBaseRowId(), total_rows.load(), row_group_size); result->stats.InitializeRemoveColumn(stats, col_idx); auto result_lock = result->stats.GetLock(); result->stats.DestroyTableSample(*result_lock); + auto result_row_groups = result->GetRowGroups(); for (auto &current_row_group : row_groups->Segments()) { auto new_row_group = current_row_group.RemoveColumn(*result, col_idx); - result->row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group)); } return result; } @@ -1313,10 +1746,11 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont Expression &cast_expr) { D_ASSERT(changed_idx < types.size()); auto new_types = types; + auto row_groups = GetRowGroups(); new_types[changed_idx] = target_type; - auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_start, - total_rows.load(), row_group_size); + auto result = make_shared_ptr(info, block_manager, std::move(new_types), + row_groups->GetBaseRowId(), total_rows.load(), row_group_size); result->stats.InitializeAlterType(stats, changed_idx, target_type); vector scan_types; @@ -1335,21 +1769,26 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont TableScanState scan_state; scan_state.Initialize(bound_columns); - scan_state.table_state.max_row = row_start + total_rows; + scan_state.table_state.Initialize(context, GetTypes()); + scan_state.table_state.max_row = row_groups->GetBaseRowId() + total_rows; // now alter the type of the column within all of the row_groups individually auto lock = result->stats.GetLock(); auto &changed_stats = result->stats.GetStats(*lock, changed_idx); - for (auto &current_row_group : row_groups->Segments()) { + auto result_row_groups = result->GetRowGroups(); + + for (auto &node : row_groups->SegmentNodes()) { + auto &current_row_group = node.GetNode(); auto new_row_group = current_row_group.AlterType(*result, target_type, changed_idx, executor, - scan_state.table_state, scan_chunk); + scan_state.table_state, node, scan_chunk); new_row_group->MergeIntoStatistics(changed_idx, changed_stats.Statistics()); - result->row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group)); } return result; } -void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { +void RowGroupCollection::VerifyNewConstraint(const QueryContext &context, DataTable &parent, +
const BoundConstraint &constraint) { if (total_rows == 0) { return; } @@ -1371,7 +1810,7 @@ void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConst CreateIndexScanState state; auto scan_type = TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; state.Initialize(column_ids, nullptr); - InitializeScan(state.table_state, column_ids, nullptr); + InitializeScan(context, state.table_state, column_ids, nullptr); InitializeCreateIndexScan(state); @@ -1393,6 +1832,11 @@ void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConst //===--------------------------------------------------------------------===// // Statistics //===---------------------------------------------------------------r-----===// + +void RowGroupCollection::SetStats(TableStatistics &new_stats) { + stats.SetStats(new_stats); +} + void RowGroupCollection::CopyStats(TableStatistics &other_stats) { stats.CopyStats(other_stats); } diff --git a/src/duckdb/src/storage/table/row_group_reorderer.cpp b/src/duckdb/src/storage/table/row_group_reorderer.cpp new file mode 100644 index 000000000..1c862753b --- /dev/null +++ b/src/duckdb/src/storage/table/row_group_reorderer.cpp @@ -0,0 +1,257 @@ +#include "duckdb/storage/table/row_group_reorderer.hpp" + +namespace duckdb { + +namespace { + +struct RowGroupSegmentNodeEntry { + reference> row_group; + unique_ptr stats; +}; + +struct RowGroupOffsetEntry { + idx_t count; + unique_ptr stats; +}; + +bool CompareValues(const Value &v1, const Value &v2, const OrderByStatistics order) { + return (order == OrderByStatistics::MAX && v1 < v2) || (order == OrderByStatistics::MIN && v1 > v2); +} + +idx_t GetQualifyingTupleCount(RowGroup &row_group, BaseStatistics &stats, const OrderByColumnType type) { + if (!stats.CanHaveNull()) { + return row_group.count; + } + + if (type == OrderByColumnType::NUMERIC) { + if (!NumericStats::HasMinMax(stats)) { + return 0; + } + if (NumericStats::IsConstant(stats)) { + return 1; + } + return 2; + } + // We cannot check if the min/max for StringStats have actually been set. As the strings may be truncated, we + // also cannot assume that min and max are the same + return 0; +} + +template +void AddRowGroups(multimap &row_group_map, It it, End end, + vector>> &ordered_row_groups, const idx_t row_limit, + const OrderByColumnType column_type, const OrderByStatistics stat_type) { + const auto opposite_stat_type = + stat_type == OrderByStatistics::MAX ? 
OrderByStatistics::MIN : OrderByStatistics::MAX; + + auto last_unresolved_entry = it; + auto &last_stats = it->second.stats; + auto last_unresolved_boundary = RowGroupReorderer::RetrieveStat(*last_stats, opposite_stat_type, column_type); + + // Try to find row groups that can be excluded with limit + idx_t qualifying_tuples = 0; + idx_t qualify_later = 0; + + idx_t last_unresolved_row_group_sum = + GetQualifyingTupleCount(it->second.row_group.get().GetNode(), *last_stats, column_type); + for (; it != end; ++it) { + auto &current_key = it->first; + auto &row_group = it->second.row_group; + + while (last_unresolved_entry != it) { + if (!CompareValues(current_key, last_unresolved_boundary, stat_type)) { + if (current_key != std::prev(it)->first) { + // Row groups overlap: we can only guarantee one additional qualifying tuple + qualifying_tuples += qualify_later; + qualify_later = 0; + qualifying_tuples++; + } else { + // Row groups have the same order value, we can only guarantee a qualifying tuple later + qualify_later++; + } + + break; + } + // Row groups do not overlap: we can guarantee that the tuples qualify + qualifying_tuples = last_unresolved_row_group_sum; + ++last_unresolved_entry; + auto &upcoming_row_group = last_unresolved_entry->second.row_group.get().GetNode(); + auto &upcoming_stats = *last_unresolved_entry->second.stats; + + last_unresolved_row_group_sum += GetQualifyingTupleCount(upcoming_row_group, upcoming_stats, column_type); + last_unresolved_boundary = RowGroupReorderer::RetrieveStat(upcoming_stats, opposite_stat_type, column_type); + } + if (qualifying_tuples >= row_limit) { + return; + } + ordered_row_groups.emplace_back(row_group); + } +} + +template +It SkipOffsetPrunedRowGroups(It it, const idx_t row_group_offset) { + for (idx_t i = 0; i < row_group_offset; i++) { + ++it; + } + return it; +} + +template +void InsertAllRowGroups(It it, End end, vector>> &ordered_row_groups) { + for (; it != end; ++it) { + ordered_row_groups.push_back(it->second.row_group); + } +} + +void SetRowGroupVector(multimap &row_group_map, const optional_idx row_limit, + const idx_t row_group_offset, const RowGroupOrderType order_type, + const OrderByColumnType column_type, + vector>> &ordered_row_groups) { + const auto stat_type = order_type == RowGroupOrderType::ASC ? OrderByStatistics::MIN : OrderByStatistics::MAX; + ordered_row_groups.reserve(row_group_map.size()); + + Value previous_key; + if (order_type == RowGroupOrderType::ASC) { + auto it = SkipOffsetPrunedRowGroups(row_group_map.begin(), row_group_offset); + auto end = row_group_map.end(); + if (row_limit.IsValid()) { + AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, stat_type); + } else { + InsertAllRowGroups(it, end, ordered_row_groups); + } + } else { + auto it = SkipOffsetPrunedRowGroups(row_group_map.rbegin(), row_group_offset); + auto end = row_group_map.rend(); + if (row_limit.IsValid()) { + AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, stat_type); + } else { + InsertAllRowGroups(it, end, ordered_row_groups); + } + } +} + +template +OffsetPruningResult FindOffsetPrunableChunks(It it, End end, const OrderByStatistics order_by, + const OrderByColumnType column_type, const idx_t row_offset) { + const auto opposite_stat_type = + order_by == OrderByStatistics::MAX ?
OrderByStatistics::MIN : OrderByStatistics::MAX; + + auto last_unresolved_entry = it; + auto last_unresolved_boundary = RowGroupReorderer::RetrieveStat(*it->second.stats, opposite_stat_type, column_type); + + // Try to find row groups that can be excluded with offset + idx_t seen_tuples = 0; + idx_t new_row_offset = row_offset; + idx_t pruned_row_group_count = 0; + + for (; it != end; ++it) { + auto &current_key = it->first; + auto tuple_count = it->second.count; + seen_tuples += tuple_count; + + while (last_unresolved_entry != it) { + if (!CompareValues(current_key, last_unresolved_boundary, order_by)) { + break; + } + // Row groups do not overlap + auto &current_stats = it->second.stats; + if (!current_stats->CanHaveNull()) { + // This row group has exactly row_group.count valid values. We can exclude those + pruned_row_group_count++; + new_row_offset -= tuple_count; + } + + ++last_unresolved_entry; + auto &upcoming_stats = *last_unresolved_entry->second.stats; + last_unresolved_boundary = RowGroupReorderer::RetrieveStat(upcoming_stats, opposite_stat_type, column_type); + } + + if (seen_tuples > row_offset) { + break; + } + } + + return {new_row_offset, pruned_row_group_count}; +} + +} // namespace + +RowGroupReorderer::RowGroupReorderer(const RowGroupOrderOptions &options_p) + : options(options_p), offset(0), initialized(false) { +} + +optional_ptr> RowGroupReorderer::GetNextRowGroup(SegmentNode &row_group) { + D_ASSERT(RefersToSameObject(ordered_row_groups[offset].get(), row_group)); + if (offset >= ordered_row_groups.size() - 1) { + return nullptr; + } + return ordered_row_groups[++offset].get(); +} + +Value RowGroupReorderer::RetrieveStat(const BaseStatistics &stats, OrderByStatistics order_by, + OrderByColumnType column_type) { + switch (order_by) { + case OrderByStatistics::MIN: + return column_type == OrderByColumnType::NUMERIC ? NumericStats::Min(stats) : StringStats::Min(stats); + case OrderByStatistics::MAX: + return column_type == OrderByColumnType::NUMERIC ?
NumericStats::Max(stats) : StringStats::Max(stats); + } + return Value(); +} + +OffsetPruningResult RowGroupReorderer::GetOffsetAfterPruning(const OrderByStatistics order_by, + const OrderByColumnType column_type, + const RowGroupOrderType order_type, + const column_t column_idx, const idx_t row_offset, + vector &stats) { + multimap ordered_row_groups; + + for (auto &partition_stats : stats) { + if (partition_stats.count_type == CountType::COUNT_APPROXIMATE || !partition_stats.partition_row_group) { + return {row_offset, 0}; + } + + auto column_stats = partition_stats.partition_row_group->GetColumnStatistics(column_idx); + Value comparison_value = RetrieveStat(*column_stats, order_by, column_type); + auto entry = RowGroupOffsetEntry {partition_stats.count, std::move(column_stats)}; + ordered_row_groups.emplace(comparison_value, std::move(entry)); + } + + if (order_type == RowGroupOrderType::ASC) { + return FindOffsetPrunableChunks(ordered_row_groups.begin(), ordered_row_groups.end(), order_by, column_type, + row_offset); + } + return FindOffsetPrunableChunks(ordered_row_groups.rbegin(), ordered_row_groups.rend(), order_by, column_type, + row_offset); +} + +optional_ptr> RowGroupReorderer::GetRootSegment(RowGroupSegmentTree &row_groups) { + if (initialized) { + if (ordered_row_groups.empty()) { + return nullptr; + } + return ordered_row_groups[0].get(); + } + + initialized = true; + + multimap row_group_map; + for (auto &row_group : row_groups.SegmentNodes()) { + auto stats = row_group.GetNode().GetStatistics(options.column_idx); + Value comparison_value = RetrieveStat(*stats, options.order_by, options.column_type); + auto entry = RowGroupSegmentNodeEntry {row_group, std::move(stats)}; + row_group_map.emplace(comparison_value, std::move(entry)); + } + + if (row_group_map.empty()) { + return nullptr; + } + + D_ASSERT(row_group_map.size() > options.row_group_offset); + SetRowGroupVector(row_group_map, options.row_limit, options.row_group_offset, options.order_type, + options.column_type, ordered_row_groups); + + return ordered_row_groups[0].get(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_id_column_data.cpp b/src/duckdb/src/storage/table/row_id_column_data.cpp index d869913bf..413c61a97 100644 --- a/src/duckdb/src/storage/table/row_id_column_data.cpp +++ b/src/duckdb/src/storage/table/row_id_column_data.cpp @@ -4,28 +4,32 @@ namespace duckdb { -RowIdColumnData::RowIdColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t start_row) - : ColumnData(block_manager, info, COLUMN_IDENTIFIER_ROW_ID, start_row, LogicalType(LogicalTypeId::BIGINT), - nullptr) { +RowIdColumnData::RowIdColumnData(BlockManager &block_manager, DataTableInfo &info) + : ColumnData(block_manager, info, COLUMN_IDENTIFIER_ROW_ID, LogicalType(LogicalTypeId::BIGINT), + ColumnDataType::MAIN_TABLE, nullptr) { + stats->statistics.SetHasNoNullFast(); } FilterPropagateResult RowIdColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { - return RowGroup::CheckRowIdFilter(filter, start, start + count); - ; + auto row_start = state.parent->row_group->GetRowStart(); + return RowGroup::CheckRowIdFilter(filter, row_start, row_start + count); } void RowIdColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { } void RowIdColumnData::InitializeScan(ColumnScanState &state) { - InitializeScanWithOffset(state, start); + InitializeScanWithOffset(state, 0); } void RowIdColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t 
row_idx) { + if (row_idx > count) { + throw InternalException("row_idx in InitializeScanWithOffset out of range"); + } state.current = nullptr; state.segment_tree = nullptr; - state.row_index = row_idx; - state.internal_index = state.row_index; + state.offset_in_column = row_idx; + state.internal_index = state.offset_in_column; state.initialized = true; state.scan_state.reset(); state.last_offset = 0; @@ -43,25 +47,26 @@ idx_t RowIdColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, void RowIdColumnData::ScanCommittedRange(idx_t row_group_start, idx_t offset_in_row_group, idx_t count, Vector &result) { - D_ASSERT(this->start == row_group_start); - result.Sequence(UnsafeNumericCast(this->start + offset_in_row_group), 1, count); + result.Sequence(UnsafeNumericCast(row_group_start + offset_in_row_group), 1, count); } idx_t RowIdColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset) { + auto row_start = state.parent->row_group->GetRowStart(); if (result_offset != 0) { throw InternalException("RowIdColumnData result_offset must be 0"); } - ScanCommittedRange(start, state.row_index - start, count, result); - state.row_index += count; + ScanCommittedRange(row_start, state.offset_in_column, count, result); + state.offset_in_column += count; return count; } void RowIdColumnData::Filter(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, SelectionVector &sel, idx_t &count, const TableFilter &filter, TableFilterState &filter_state) { - auto current_row = state.row_index; + auto row_start = state.parent->row_group->GetRowStart(); + auto current_row = row_start + state.offset_in_column; auto max_count = GetVectorCount(vector_index); - state.row_index += max_count; + state.offset_in_column += max_count; // We do another quick statistics scan for row ids here const auto rowid_start = current_row; const auto rowid_end = current_row + max_count; @@ -100,10 +105,11 @@ void RowIdColumnData::SelectCommitted(idx_t vector_index, ColumnScanState &state idx_t count, bool allow_updates) { result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::GetData(result); + auto row_start = state.parent->row_group->GetRowStart(); for (size_t sel_idx = 0; sel_idx < count; sel_idx++) { - result_data[sel_idx] = UnsafeNumericCast(state.row_index + sel.get_index(sel_idx)); + result_data[sel_idx] = UnsafeNumericCast(row_start + state.offset_in_column + sel.get_index(sel_idx)); } - state.row_index += GetVectorCount(vector_index); + state.offset_in_column += GetVectorCount(vector_index); } idx_t RowIdColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { @@ -114,11 +120,13 @@ void RowIdColumnData::FetchRow(TransactionData transaction, ColumnFetchState &st idx_t result_idx) { result.SetVectorType(VectorType::FLAT_VECTOR); auto data = FlatVector::GetData(result); - data[result_idx] = row_id; + auto row_start = state.row_group->GetRowStart(); + data[result_idx] = UnsafeNumericCast(row_start) + row_id; } + void RowIdColumnData::Skip(ColumnScanState &state, idx_t count) { - state.row_index += count; - state.internal_index = state.row_index; + state.offset_in_column += count; + state.internal_index = state.offset_in_column; } void RowIdColumnData::InitializeAppend(ColumnAppendState &state) { @@ -134,35 +142,36 @@ void RowIdColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::RevertAppend(row_t 
start_row) { +void RowIdColumnData::RevertAppend(row_t new_count) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void RowIdColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { throw InternalException("RowIdColumnData cannot be updated"); } -void RowIdColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void RowIdColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { throw InternalException("RowIdColumnData cannot be updated"); } -void RowIdColumnData::CommitDropColumn() { - throw InternalException("RowIdColumnData cannot be dropped"); +void RowIdColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + throw InternalException("VisitBlockIds not supported for rowid"); } -unique_ptr RowIdColumnData::CreateCheckpointState(RowGroup &row_group, +unique_ptr RowIdColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { throw InternalException("RowIdColumnData cannot be checkpointed"); } -unique_ptr RowIdColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) { +unique_ptr RowIdColumnData::Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &info) { throw InternalException("RowIdColumnData cannot be checkpointed"); } -void RowIdColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, - Vector &scan_vector) { +void RowIdColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, + Vector &scan_vector) const { throw InternalException("RowIdColumnData cannot be checkpointed"); } diff --git a/src/duckdb/src/storage/table/row_version_manager.cpp b/src/duckdb/src/storage/table/row_version_manager.cpp index df4e463da..20d0ebed4 100644 --- a/src/duckdb/src/storage/table/row_version_manager.cpp +++ b/src/duckdb/src/storage/table/row_version_manager.cpp @@ -7,19 +7,10 @@ namespace duckdb { -RowVersionManager::RowVersionManager(idx_t start) noexcept : start(start), has_changes(false) { -} - -void RowVersionManager::SetStart(idx_t new_start) { - lock_guard l(version_lock); - this->start = new_start; - idx_t current_start = start; - for (auto &info : vector_info) { - if (info) { - info->start = current_start; - } - current_start += STANDARD_VECTOR_SIZE; - } +RowVersionManager::RowVersionManager(BufferManager &buffer_manager_p) noexcept + : allocator(STANDARD_VECTOR_SIZE * sizeof(transaction_t), buffer_manager_p.GetTemporaryBlockManager(), + MemoryTag::BASE_TABLE), + has_unserialized_changes(false) { } idx_t RowVersionManager::GetCommittedDeletedCount(idx_t count) { @@ -45,6 +36,53 @@ optional_ptr RowVersionManager::GetChunkInfo(idx_t vector_idx) { return vector_info[vector_idx].get(); } +bool RowVersionManager::ShouldCheckpointRowGroup(transaction_t checkpoint_id, idx_t count) { + lock_guard l(version_lock); + TransactionData checkpoint_transaction(checkpoint_id, checkpoint_id); + + idx_t total_count = 0; + for (idx_t read_count = 0, vector_idx = 0; read_count < count; read_count += STANDARD_VECTOR_SIZE, vector_idx++) { 
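+ // Walk the row group one STANDARD_VECTOR_SIZE chunk at a time and ask each chunk info how many
+ // rows this checkpoint transaction would write for that vector. Vectors with nothing to write are
+ // skipped; if the rows gathered so far stop matching the rows scanned so far, the insertions were
+ // not sequential and an InternalException with per-vector diagnostics is raised below.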
+ idx_t max_count = MinValue(count - read_count, STANDARD_VECTOR_SIZE); + idx_t checkpoint_count; + auto chunk_info = GetChunkInfo(vector_idx); + if (!chunk_info) { + checkpoint_count = max_count; + } else { + checkpoint_count = chunk_info->GetCheckpointRowCount(checkpoint_transaction, max_count); + } + if (checkpoint_count == 0) { + continue; + } + if (total_count != read_count) { + string chunk_info_text; + for (idx_t i = 0; i <= vector_idx; i++) { + auto current_info = GetChunkInfo(i); + chunk_info_text += "\n"; + chunk_info_text += to_string(i) + ": "; + if (current_info) { + chunk_info_text += current_info->ToString(max_count); + } else { + chunk_info_text += "(empty)"; + } + } + throw InternalException( + "Error in RowGroup::GetCheckpointRowCount - insertions are not sequential - at vector idx %d found %d " + "rows, where we have already obtained %d from the total %d, transaction start time %d%s", + vector_idx, checkpoint_count, total_count, read_count, checkpoint_id, chunk_info_text); + } + total_count += checkpoint_count; + } + if (total_count == 0) { + return false; + } + if (total_count != count) { + throw InternalException("RowGroup::GetCheckpointRowCount returned a partially checkpointed entry (checkpoint " + "count %d, row group count %d)", + total_count, count); + } + return true; +} + idx_t RowVersionManager::GetSelVector(TransactionData transaction, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count) { lock_guard l(version_lock); @@ -88,7 +126,7 @@ void RowVersionManager::FillVectorInfo(idx_t vector_idx) { void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, idx_t row_group_end) { lock_guard lock(version_lock); - has_changes = true; + has_unserialized_changes = true; idx_t start_vector_idx = row_group_start / STANDARD_VECTOR_SIZE; idx_t end_vector_idx = (row_group_end - 1) / STANDARD_VECTOR_SIZE; @@ -103,7 +141,7 @@ void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t cou vector_idx == end_vector_idx ? row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; if (vector_start == 0 && vector_end == STANDARD_VECTOR_SIZE) { // entire vector is encapsulated by append: append a single constant - auto constant_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + auto constant_info = make_uniq(vector_idx * STANDARD_VECTOR_SIZE); constant_info->insert_id = transaction.transaction_id; constant_info->delete_id = NOT_DELETED_ID; vector_info[vector_idx] = std::move(constant_info); @@ -112,7 +150,7 @@ void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t cou optional_ptr new_info; if (!vector_info[vector_idx]) { // first time appending to this vector: create new info - auto insert_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + auto insert_info = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE); new_info = insert_info.get(); vector_info[vector_idx] = std::move(insert_info); } else if (vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO) { @@ -141,6 +179,7 @@ void RowVersionManager::CommitAppend(transaction_t commit_id, idx_t row_group_st idx_t vend = vector_idx == end_vector_idx ? 
row_group_end - end_vector_idx * STANDARD_VECTOR_SIZE : STANDARD_VECTOR_SIZE; auto &info = *vector_info[vector_idx]; + D_ASSERT(has_unserialized_changes); info.CommitAppend(commit_id, vstart, vend); } } @@ -167,18 +206,21 @@ void RowVersionManager::CleanupAppend(transaction_t lowest_active_transaction, i } auto &info = *vector_info[vector_idx]; // if we wrote the entire chunk info try to compress it - unique_ptr new_info; - auto cleanup = info.Cleanup(lowest_active_transaction, new_info); + auto cleanup = info.Cleanup(lowest_active_transaction); if (cleanup) { - vector_info[vector_idx] = std::move(new_info); + if (info.HasDeletes()) { + has_unserialized_changes = true; + } + vector_info[vector_idx].reset(); } } } -void RowVersionManager::RevertAppend(idx_t start_row) { +void RowVersionManager::RevertAppend(idx_t new_count) { lock_guard lock(version_lock); - idx_t start_vector_idx = (start_row + (STANDARD_VECTOR_SIZE - 1)) / STANDARD_VECTOR_SIZE; + idx_t start_vector_idx = (new_count + (STANDARD_VECTOR_SIZE - 1)) / STANDARD_VECTOR_SIZE; for (idx_t vector_idx = start_vector_idx; vector_idx < vector_info.size(); vector_idx++) { + D_ASSERT(has_unserialized_changes); vector_info[vector_idx].reset(); } } @@ -188,15 +230,11 @@ ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { if (!vector_info[vector_idx]) { // no info yet: create it - vector_info[vector_idx] = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); + vector_info[vector_idx] = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE); } else if (vector_info[vector_idx]->type == ChunkInfoType::CONSTANT_INFO) { auto &constant = vector_info[vector_idx]->Cast(); // info exists but it's a constant info: convert to a vector info - auto new_info = make_uniq(start + vector_idx * STANDARD_VECTOR_SIZE); - new_info->insert_id = constant.insert_id; - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - new_info->inserted[i] = constant.insert_id; - } + auto new_info = make_uniq(allocator, vector_idx * STANDARD_VECTOR_SIZE, constant.insert_id); vector_info[vector_idx] = std::move(new_info); } D_ASSERT(vector_info[vector_idx]->type == ChunkInfoType::VECTOR_INFO); @@ -205,19 +243,19 @@ ChunkVectorInfo &RowVersionManager::GetVectorInfo(idx_t vector_idx) { idx_t RowVersionManager::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t rows[], idx_t count) { lock_guard lock(version_lock); - has_changes = true; + has_unserialized_changes = true; return GetVectorInfo(vector_idx).Delete(transaction_id, rows, count); } void RowVersionManager::CommitDelete(idx_t vector_idx, transaction_t commit_id, const DeleteInfo &info) { lock_guard lock(version_lock); - has_changes = true; + has_unserialized_changes = true; GetVectorInfo(vector_idx).CommitDelete(commit_id, info); } vector RowVersionManager::Checkpoint(MetadataManager &manager) { - if (!has_changes && !storage_pointers.empty()) { - // the row version manager already exists on disk and no changes were made + lock_guard lock(version_lock); + if (!has_unserialized_changes) { // we can write the current pointer as-is // ensure the blocks we are pointing to are not marked as free manager.ClearModifiedBlocks(storage_pointers); @@ -236,33 +274,32 @@ vector RowVersionManager::Checkpoint(MetadataManager &manager) } to_serialize.emplace_back(vector_idx, *chunk_info); } - if (to_serialize.empty()) { - return vector(); - } storage_pointers.clear(); - MetadataWriter writer(manager, &storage_pointers); - // now serialize the actual version information - writer.Write(to_serialize.size()); - 
for (auto &entry : to_serialize) { - auto &vector_idx = entry.first; - auto &chunk_info = entry.second.get(); - writer.Write(vector_idx); - chunk_info.Write(writer); + if (!to_serialize.empty()) { + MetadataWriter writer(manager, &storage_pointers); + // now serialize the actual version information + writer.Write(to_serialize.size()); + for (auto &entry : to_serialize) { + auto &vector_idx = entry.first; + auto &chunk_info = entry.second.get(); + writer.Write(vector_idx); + chunk_info.Write(writer); + } + writer.Flush(); } - writer.Flush(); - has_changes = false; + has_unserialized_changes = false; return storage_pointers; } -shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, MetadataManager &manager, - idx_t start) { +shared_ptr RowVersionManager::Deserialize(MetaBlockPointer delete_pointer, + MetadataManager &manager) { if (!delete_pointer.IsValid()) { return nullptr; } - auto version_info = make_shared_ptr(start); + auto version_info = make_shared_ptr(manager.GetBufferManager()); MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); auto chunk_count = source.Read(); D_ASSERT(chunk_count > 0); @@ -275,10 +312,21 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de } version_info->FillVectorInfo(vector_index); - version_info->vector_info[vector_index] = ChunkInfo::Read(source); + version_info->vector_info[vector_index] = ChunkInfo::Read(version_info->GetAllocator(), source); } - version_info->has_changes = false; + version_info->has_unserialized_changes = false; return version_info; } +bool RowVersionManager::HasUnserializedChanges() { + lock_guard lock(version_lock); + return has_unserialized_changes; +} + +vector RowVersionManager::GetStoragePointers() { + lock_guard lock(version_lock); + D_ASSERT(!has_unserialized_changes); + return storage_pointers; +} + } // namespace duckdb diff --git a/src/duckdb/src/storage/table/scan_state.cpp b/src/duckdb/src/storage/table/scan_state.cpp index f7f5f727d..01314d7d1 100644 --- a/src/duckdb/src/storage/table/scan_state.cpp +++ b/src/duckdb/src/storage/table/scan_state.cpp @@ -133,16 +133,21 @@ void ColumnScanState::NextInternal(idx_t count) { //! There is no column segment return; } - row_index += count; - while (row_index >= current->start + current->count) { - current = segment_tree->GetNextSegment(current); + offset_in_column += count; + while (offset_in_column >= current->GetRowStart() + current->GetNode().count) { + current = segment_tree->GetNextSegment(*current); initialized = false; segment_checked = false; if (!current) { break; } } - D_ASSERT(!current || (row_index >= current->start && row_index < current->start + current->count)); + D_ASSERT(!current || (offset_in_column >= current->GetRowStart() && + offset_in_column < current->GetRowStart() + current->GetNode().count)); +} + +idx_t ColumnScanState::GetPositionInSegment() const { + return offset_in_column - (current ? 
current->GetRowStart() : 0); } void ColumnScanState::Next(idx_t count) { @@ -174,28 +179,63 @@ ParallelCollectionScanState::ParallelCollectionScanState() : collection(nullptr), current_row_group(nullptr), processed_rows(0) { } +optional_ptr> ParallelCollectionScanState::GetRootSegment(RowGroupSegmentTree &row_groups) const { + if (reorderer) { + return reorderer->GetRootSegment(row_groups); + } + return row_groups.GetRootSegment(); +} + +optional_ptr> +ParallelCollectionScanState::GetNextRowGroup(RowGroupSegmentTree &row_groups, SegmentNode &row_group) const { + if (reorderer) { + return reorderer->GetNextRowGroup(row_group); + } + return row_groups.GetNextSegment(row_group); +} + CollectionScanState::CollectionScanState(TableScanState &parent_p) : row_group(nullptr), vector_index(0), max_row_group_row(0), row_groups(nullptr), max_row(0), batch_index(0), valid_sel(STANDARD_VECTOR_SIZE), random(-1), parent(parent_p) { } +optional_ptr> CollectionScanState::GetNextRowGroup(SegmentNode &row_group) const { + if (reorderer) { + return reorderer->GetNextRowGroup(row_group); + } + return row_groups->GetNextSegment(row_group); +} + +optional_ptr> CollectionScanState::GetNextRowGroup(SegmentLock &l, + SegmentNode &row_group) const { + D_ASSERT(!reorderer); + return row_groups->GetNextSegment(l, row_group); +} + +optional_ptr> CollectionScanState::GetRootSegment() const { + if (reorderer) { + return reorderer->GetRootSegment(*row_groups); + } + return row_groups->GetRootSegment(); +} + bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) { while (row_group) { - row_group->Scan(transaction, *this, result); + row_group->GetNode().Scan(transaction, *this, result); if (result.size() > 0) { return true; - } else if (max_row <= row_group->start + row_group->count) { + } else if (max_row <= row_group->GetRowStart() + row_group->GetNode().count) { row_group = nullptr; return false; } else { do { - row_group = row_groups->GetNextSegment(row_group); + row_group = GetNextRowGroup(*row_group).get(); if (row_group) { - if (row_group->start >= max_row) { + if (row_group->GetRowStart() >= max_row) { row_group = nullptr; break; } - bool scan_row_group = row_group->InitializeScan(*this); + bool scan_row_group = row_group->GetNode().InitializeScan(*this, *row_group); if (scan_row_group) { // scan this row group break; @@ -209,13 +249,13 @@ bool CollectionScanState::Scan(DuckTransaction &transaction, DataChunk &result) bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, TableScanType type) { while (row_group) { - row_group->ScanCommitted(*this, result, type); + row_group->GetNode().ScanCommitted(*this, result, type); if (result.size() > 0) { return true; } else { - row_group = row_groups->GetNextSegment(l, row_group); + row_group = GetNextRowGroup(l, *row_group).get(); if (row_group) { - row_group->InitializeScan(*this); + row_group->GetNode().InitializeScan(*this, *row_group); } } } @@ -224,14 +264,14 @@ bool CollectionScanState::ScanCommitted(DataChunk &result, SegmentLock &l, Table bool CollectionScanState::ScanCommitted(DataChunk &result, TableScanType type) { while (row_group) { - row_group->ScanCommitted(*this, result, type); + row_group->GetNode().ScanCommitted(*this, result, type); if (result.size() > 0) { return true; - } else { - row_group = row_groups->GetNextSegment(row_group); - if (row_group) { - row_group->InitializeScan(*this); - } + } + + row_group = GetNextRowGroup(*row_group).get(); + if (row_group) { + row_group->GetNode().InitializeScan(*this, 
*row_group); } } return false; diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp index c657c63ee..3ddc6381b 100644 --- a/src/duckdb/src/storage/table/standard_column_data.cpp +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -12,14 +12,17 @@ namespace duckdb { StandardColumnData::StandardColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, LogicalType type, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type), parent), - validity(block_manager, info, 0, start_row, *this) { + LogicalType type, ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, std::move(type), data_type, parent) { + if (data_type != ColumnDataType::CHECKPOINT_TARGET) { + // don't initialize the child entry if this is a checkpoint target + validity = make_shared_ptr(block_manager, info, 0, *this); + } } -void StandardColumnData::SetStart(idx_t new_start) { - ColumnData::SetStart(new_start); - validity.SetStart(new_start); +void StandardColumnData::SetDataType(ColumnDataType data_type) { + ColumnData::SetDataType(data_type); + validity->SetDataType(data_type); } ScanVectorType StandardColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_count, Vector &result) { @@ -31,12 +34,12 @@ ScanVectorType StandardColumnData::GetVectorScanType(ColumnScanState &state, idx if (state.child_states.empty()) { return scan_type; } - return validity.GetVectorScanType(state.child_states[0], scan_count, result); + return validity->GetVectorScanType(state.child_states[0], scan_count, result); } void StandardColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { ColumnData::InitializePrefetch(prefetch_state, scan_state, rows); - validity.InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); + validity->InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); } void StandardColumnData::InitializeScan(ColumnScanState &state) { @@ -44,7 +47,7 @@ void StandardColumnData::InitializeScan(ColumnScanState &state) { // initialize the validity segment D_ASSERT(state.child_states.size() == 1); - validity.InitializeScan(state.child_states[0]); + validity->InitializeScan(state.child_states[0]); } void StandardColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { @@ -52,30 +55,30 @@ void StandardColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t // initialize the validity segment D_ASSERT(state.child_states.size() == 1); - validity.InitializeScanWithOffset(state.child_states[0], row_idx); + validity->InitializeScanWithOffset(state.child_states[0], row_idx); } idx_t StandardColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, idx_t target_count) { - D_ASSERT(state.row_index == state.child_states[0].row_index); + D_ASSERT(state.offset_in_column == state.child_states[0].offset_in_column); auto scan_type = GetVectorScanType(state, target_count, result); auto mode = ScanVectorMode::REGULAR_SCAN; auto scan_count = ScanVector(transaction, vector_index, state, result, target_count, scan_type, mode); - validity.ScanVector(transaction, vector_index, state.child_states[0], result, target_count, scan_type, mode); + validity->ScanVector(transaction, vector_index, state.child_states[0], result, target_count, scan_type, mode); return scan_count; } idx_t 
StandardColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates, idx_t target_count) { - D_ASSERT(state.row_index == state.child_states[0].row_index); + D_ASSERT(state.offset_in_column == state.child_states[0].offset_in_column); auto scan_count = ColumnData::ScanCommitted(vector_index, state, result, allow_updates, target_count); - validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates, target_count); + validity->ScanCommitted(vector_index, state.child_states[0], result, allow_updates, target_count); return scan_count; } idx_t StandardColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset) { auto scan_count = ColumnData::ScanCount(state, result, count, result_offset); - validity.ScanCount(state.child_states[0], result, count, result_offset); + validity->ScanCount(state.child_states[0], result, count, result_offset); return scan_count; } @@ -86,7 +89,7 @@ void StandardColumnData::Filter(TransactionData transaction, idx_t vector_index, // the compression functions need to support this auto compression = GetCompressionFunction(); bool has_filter = compression && compression->filter; - auto validity_compression = validity.GetCompressionFunction(); + auto validity_compression = validity->GetCompressionFunction(); bool validity_has_filter = validity_compression && validity_compression->filter; auto target_count = GetVectorCount(vector_index); auto scan_type = GetVectorScanType(state, target_count, result); @@ -98,7 +101,7 @@ void StandardColumnData::Filter(TransactionData transaction, idx_t vector_index, return; } FilterVector(state, result, target_count, sel, count, filter, filter_state); - validity.FilterVector(state.child_states[0], result, target_count, sel, count, filter, filter_state); + validity->FilterVector(state.child_states[0], result, target_count, sel, count, filter, filter_state); } void StandardColumnData::Select(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, @@ -107,7 +110,7 @@ void StandardColumnData::Select(TransactionData transaction, idx_t vector_index, // the compression functions need to support this auto compression = GetCompressionFunction(); bool has_select = compression && compression->select; - auto validity_compression = validity.GetCompressionFunction(); + auto validity_compression = validity->GetCompressionFunction(); bool validity_has_select = validity_compression && validity_compression->select; auto target_count = GetVectorCount(vector_index); auto scan_type = GetVectorScanType(state, target_count, result); @@ -118,68 +121,75 @@ void StandardColumnData::Select(TransactionData transaction, idx_t vector_index, return; } SelectVector(state, result, target_count, sel, sel_count); - validity.SelectVector(state.child_states[0], result, target_count, sel, sel_count); + validity->SelectVector(state.child_states[0], result, target_count, sel, sel_count); } void StandardColumnData::InitializeAppend(ColumnAppendState &state) { ColumnData::InitializeAppend(state); ColumnAppendState child_append; - validity.InitializeAppend(child_append); + validity->InitializeAppend(child_append); state.child_appends.push_back(std::move(child_append)); } void StandardColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { ColumnData::AppendData(stats, state, vdata, count); - validity.AppendData(stats, state.child_appends[0], vdata, count); + validity->AppendData(stats, 
state.child_appends[0], vdata, count); } -void StandardColumnData::RevertAppend(row_t start_row) { - ColumnData::RevertAppend(start_row); - - validity.RevertAppend(start_row); +void StandardColumnData::RevertAppend(row_t new_count) { + ColumnData::RevertAppend(new_count); + validity->RevertAppend(new_count); } idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { // fetch validity mask if (state.child_states.empty()) { - ColumnScanState child_state; + ColumnScanState child_state(state.parent); child_state.scan_options = state.scan_options; state.child_states.push_back(std::move(child_state)); } auto scan_count = ColumnData::Fetch(state, row_id, result); - validity.Fetch(state.child_states[0], row_id, result); + validity->Fetch(state.child_states[0], row_id, result); return scan_count; } -void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - ColumnScanState standard_state, validity_state; +void StandardColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { + ColumnScanState standard_state(nullptr); + ColumnScanState validity_state(nullptr); Vector base_vector(type); - auto standard_fetch = FetchUpdateData(standard_state, row_ids, base_vector); - auto validity_fetch = validity.FetchUpdateData(validity_state, row_ids, base_vector); + auto standard_fetch = FetchUpdateData(standard_state, row_ids, base_vector, row_group_start); + auto validity_fetch = validity->FetchUpdateData(validity_state, row_ids, base_vector, row_group_start); if (standard_fetch != validity_fetch) { throw InternalException("Unaligned fetch in validity and main column data for update"); } - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); - validity.UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector, + row_group_start); + validity->UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector, + row_group_start); } -void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StandardColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { if (depth >= column_path.size()) { // update this column - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count, + row_group_start); } else { // update the child column (i.e. the validity column) - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + validity->UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, update_count, depth + 1, + row_group_start); + validity->UpdateWithBase(transaction, data_table, column_path[0], update_vector, row_ids, update_count, *this, + row_group_start); } } unique_ptr StandardColumnData::GetUpdateStatistics() { auto stats = updates ? 
updates->GetStatistics() : nullptr; - auto validity_stats = validity.GetUpdateStatistics(); + auto validity_stats = validity->GetUpdateStatistics(); if (!stats && !validity_stats) { return nullptr; } @@ -199,17 +209,25 @@ void StandardColumnData::FetchRow(TransactionData transaction, ColumnFetchState auto child_state = make_uniq(); state.child_states.push_back(std::move(child_state)); } - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); ColumnData::FetchRow(transaction, state, row_id, result, result_idx); + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); } -void StandardColumnData::CommitDropColumn() { - ColumnData::CommitDropColumn(); - validity.CommitDropColumn(); +void StandardColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + ColumnData::VisitBlockIds(visitor); + validity->VisitBlockIds(visitor); +} + +void StandardColumnData::SetValidityData(shared_ptr validity_p) { + if (validity) { + throw InternalException("StandardColumnData::SetValidityData cannot be used to overwrite existing validity"); + } + validity_p->SetParent(this); + this->validity = std::move(validity_p); } struct StandardColumnCheckpointState : public ColumnCheckpointState { - StandardColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, + StandardColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) : ColumnCheckpointState(row_group, column_data, partial_block_manager) { } @@ -217,8 +235,24 @@ struct StandardColumnCheckpointState : public ColumnCheckpointState { unique_ptr validity_state; public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, original_column.type, + ColumnDataType::CHECKPOINT_TARGET, nullptr); + } + + shared_ptr GetFinalResult() override { + if (result_column) { + auto &column_data = result_column->Cast(); + auto validity_child = validity_state->GetFinalResult(); + column_data.SetValidityData(shared_ptr_cast(std::move(validity_child))); + } + return ColumnCheckpointState::GetFinalResult(); + } + unique_ptr GetStatistics() override { D_ASSERT(global_stats); + global_stats->Merge(*validity_state->GetStatistics()); return std::move(global_stats); } @@ -230,28 +264,28 @@ struct StandardColumnCheckpointState : public ColumnCheckpointState { }; unique_ptr -StandardColumnData::CreateCheckpointState(RowGroup &row_group, PartialBlockManager &partial_block_manager) { +StandardColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { return make_uniq(row_group, *this, partial_block_manager); } -unique_ptr StandardColumnData::Checkpoint(RowGroup &row_group, +unique_ptr StandardColumnData::Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { // we need to checkpoint the main column data first // that is because the checkpointing of the main column data ALSO scans the validity data // to prevent reading the validity data immediately after it is checkpointed we first checkpoint the main column // this is necessary for concurrent checkpointing as due to the partial block manager checkpointed data might be // flushed to disk by a different thread than the one that wrote it, causing a data race - auto base_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto base_state 
= CreateCheckpointState(row_group, partial_block_manager); base_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); - auto validity_state_p = validity.CreateCheckpointState(row_group, checkpoint_info.info.manager); - validity_state_p->global_stats = BaseStatistics::CreateEmpty(validity.type).ToUnique(); + auto validity_state_p = validity->CreateCheckpointState(row_group, partial_block_manager); + validity_state_p->global_stats = BaseStatistics::CreateEmpty(validity->type).ToUnique(); auto &validity_state = *validity_state_p; auto &checkpoint_state = base_state->Cast(); checkpoint_state.validity_state = std::move(validity_state_p); - auto &nodes = data.ReferenceSegments(); - if (nodes.empty()) { + if (!data.GetRootSegment()) { // empty table: flush the empty list return base_state; } @@ -264,47 +298,51 @@ unique_ptr StandardColumnData::Checkpoint(RowGroup &row_g checkpointer.Checkpoint(); checkpointer.FinalizeCheckpoint(); + // merge validity stats into base stats + base_state->global_stats->Merge(*validity_state.global_stats); + return base_state; } -void StandardColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, - idx_t count, Vector &scan_vector) { - ColumnData::CheckpointScan(segment, state, row_group_start, count, scan_vector); +void StandardColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t count, + Vector &scan_vector) const { + ColumnData::CheckpointScan(segment, state, count, scan_vector); - idx_t offset_in_row_group = state.row_index - row_group_start; - validity.ScanCommittedRange(row_group_start, offset_in_row_group, count, scan_vector); + idx_t offset_in_row_group = state.offset_in_column; + validity->ScanCommittedRange(0, offset_in_row_group, count, scan_vector); } bool StandardColumnData::IsPersistent() { - return ColumnData::IsPersistent() && validity.IsPersistent(); + return ColumnData::IsPersistent() && validity->IsPersistent(); } bool StandardColumnData::HasAnyChanges() const { - return ColumnData::HasAnyChanges() || validity.HasAnyChanges(); + return ColumnData::HasAnyChanges() || validity->HasAnyChanges(); } PersistentColumnData StandardColumnData::Serialize() { auto persistent_data = ColumnData::Serialize(); - persistent_data.child_columns.push_back(validity.Serialize()); + persistent_data.child_columns.push_back(validity->Serialize()); return persistent_data; } void StandardColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) { ColumnData::InitializeColumn(column_data, target_stats); - validity.InitializeColumn(column_data.child_columns[0], target_stats); + validity->InitializeColumn(column_data.child_columns[0], target_stats); } -void StandardColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, +void StandardColumnData::GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, std::move(col_path), result); + validity->GetColumnSegmentInfo(context, row_group_index, std::move(col_path), result); } void StandardColumnData::Verify(RowGroup &parent) { #ifdef DEBUG ColumnData::Verify(parent); - validity.Verify(parent); + validity->Verify(parent); #endif } diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp 
b/src/duckdb/src/storage/table/struct_column_data.cpp index 5137330ef..74c509c3f 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -10,9 +10,8 @@ namespace duckdb { StructColumnData::StructColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - idx_t start_row, LogicalType type_p, optional_ptr parent) - : ColumnData(block_manager, info, column_index, start_row, std::move(type_p), parent), - validity(block_manager, info, 0, start_row, *this) { + LogicalType type_p, ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, std::move(type_p), data_type, parent) { D_ASSERT(type.InternalType() == PhysicalType::STRUCT); auto &child_types = StructType::GetChildTypes(type); D_ASSERT(!child_types.empty()); @@ -22,21 +21,27 @@ StructColumnData::StructColumnData(BlockManager &block_manager, DataTableInfo &i if (type.id() == LogicalTypeId::VARIANT) { throw NotImplementedException("A table cannot be created from a VARIANT column yet"); } - // the sub column index, starting at 1 (0 is the validity mask) - idx_t sub_column_index = 1; - for (auto &child_type : child_types) { - sub_columns.push_back( - ColumnData::CreateColumnUnique(block_manager, info, sub_column_index, start_row, child_type.second, this)); - sub_column_index++; + if (data_type != ColumnDataType::CHECKPOINT_TARGET) { + validity = make_shared_ptr(block_manager, info, 0, *this); + // the sub column index, starting at 1 (0 is the validity mask) + idx_t sub_column_index = 1; + for (auto &child_type : child_types) { + sub_columns.push_back( + ColumnData::CreateColumn(block_manager, info, sub_column_index, child_type.second, data_type, this)); + sub_column_index++; + } + } else { + // initialize to empty + sub_columns.resize(child_types.size()); } } -void StructColumnData::SetStart(idx_t new_start) { - this->start = new_start; +void StructColumnData::SetDataType(ColumnDataType data_type) { + ColumnData::SetDataType(data_type); for (auto &sub_column : sub_columns) { - sub_column->SetStart(new_start); + sub_column->SetDataType(data_type); } - validity.SetStart(new_start); + validity->SetDataType(data_type); } idx_t StructColumnData::GetMaxEntry() { @@ -44,7 +49,7 @@ idx_t StructColumnData::GetMaxEntry() { } void StructColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { - validity.InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); + validity->InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); for (idx_t i = 0; i < sub_columns.size(); i++) { if (!scan_state.scan_child_column[i]) { continue; @@ -55,11 +60,11 @@ void StructColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnS void StructColumnData::InitializeScan(ColumnScanState &state) { D_ASSERT(state.child_states.size() == sub_columns.size() + 1); - state.row_index = 0; + state.offset_in_column = 0; state.current = nullptr; // initialize the validity segment - validity.InitializeScan(state.child_states[0]); + validity->InitializeScan(state.child_states[0]); // initialize the sub-columns for (idx_t i = 0; i < sub_columns.size(); i++) { @@ -72,11 +77,12 @@ void StructColumnData::InitializeScan(ColumnScanState &state) { void StructColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { D_ASSERT(state.child_states.size() == sub_columns.size() + 1); - state.row_index = row_idx; + D_ASSERT(row_idx < count); + state.offset_in_column = 
row_idx; state.current = nullptr; // initialize the validity segment - validity.InitializeScanWithOffset(state.child_states[0], row_idx); + validity->InitializeScanWithOffset(state.child_states[0], row_idx); // initialize the sub-columns for (idx_t i = 0; i < sub_columns.size(); i++) { @@ -89,7 +95,7 @@ void StructColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t ro idx_t StructColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, idx_t target_count) { - auto scan_count = validity.Scan(transaction, vector_index, state.child_states[0], result, target_count); + auto scan_count = validity->Scan(transaction, vector_index, state.child_states[0], result, target_count); auto &child_entries = StructVector::GetEntries(result); for (idx_t i = 0; i < sub_columns.size(); i++) { auto &target_vector = *child_entries[i]; @@ -106,7 +112,7 @@ idx_t StructColumnData::Scan(TransactionData transaction, idx_t vector_index, Co idx_t StructColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates, idx_t target_count) { - auto scan_count = validity.ScanCommitted(vector_index, state.child_states[0], result, allow_updates, target_count); + auto scan_count = validity->ScanCommitted(vector_index, state.child_states[0], result, allow_updates, target_count); auto &child_entries = StructVector::GetEntries(result); for (idx_t i = 0; i < sub_columns.size(); i++) { auto &target_vector = *child_entries[i]; @@ -123,7 +129,7 @@ idx_t StructColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state } idx_t StructColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset) { - auto scan_count = validity.ScanCount(state.child_states[0], result, count); + auto scan_count = validity->ScanCount(state.child_states[0], result, count); auto &child_entries = StructVector::GetEntries(result); for (idx_t i = 0; i < sub_columns.size(); i++) { auto &target_vector = *child_entries[i]; @@ -139,7 +145,7 @@ idx_t StructColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t } void StructColumnData::Skip(ColumnScanState &state, idx_t count) { - validity.Skip(state.child_states[0], count); + validity->Skip(state.child_states[0], count); // skip inside the sub-columns for (idx_t child_idx = 0; child_idx < sub_columns.size(); child_idx++) { @@ -152,7 +158,7 @@ void StructColumnData::Skip(ColumnScanState &state, idx_t count) { void StructColumnData::InitializeAppend(ColumnAppendState &state) { ColumnAppendState validity_append; - validity.InitializeAppend(validity_append); + validity->InitializeAppend(validity_append); state.child_appends.push_back(std::move(validity_append)); for (auto &sub_column : sub_columns) { @@ -171,7 +177,7 @@ void StructColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, V } // append the null values - validity.Append(stats, state.child_appends[0], vector, count); + validity->Append(stats, state.child_appends[0], vector, count); auto &child_entries = StructVector::GetEntries(vector); for (idx_t i = 0; i < child_entries.size(); i++) { @@ -181,12 +187,12 @@ void StructColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, V this->count += count; } -void StructColumnData::RevertAppend(row_t start_row) { - validity.RevertAppend(start_row); +void StructColumnData::RevertAppend(row_t new_count) { + validity->RevertAppend(new_count); for (auto &sub_column : sub_columns) { - sub_column->RevertAppend(start_row); + 
sub_column->RevertAppend(new_count); } - this->count = UnsafeNumericCast(start_row) - this->start; + this->count = UnsafeNumericCast(new_count); } idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { @@ -194,12 +200,12 @@ idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resu auto &child_entries = StructVector::GetEntries(result); // insert any child states that are required for (idx_t i = state.child_states.size(); i < child_entries.size() + 1; i++) { - ColumnScanState child_state; + ColumnScanState child_state(state.parent); child_state.scan_options = state.scan_options; state.child_states.push_back(std::move(child_state)); } // fetch the validity state - idx_t scan_count = validity.Fetch(state.child_states[0], row_id, result); + idx_t scan_count = validity->Fetch(state.child_states[0], row_id, result); // fetch the sub-column states for (idx_t i = 0; i < child_entries.size(); i++) { sub_columns[i]->Fetch(state.child_states[i + 1], row_id, *child_entries[i]); @@ -207,17 +213,19 @@ idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resu return scan_count; } -void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - validity.Update(transaction, column_index, update_vector, row_ids, update_count); +void StructColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { + validity->Update(transaction, data_table, column_index, update_vector, row_ids, update_count, row_group_start); auto &child_entries = StructVector::GetEntries(update_vector); for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Update(transaction, column_index, *child_entries[i], row_ids, update_count); + sub_columns[i]->Update(transaction, data_table, column_index, *child_entries[i], row_ids, update_count, + row_group_start); } } -void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StructColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { // we can never DIRECTLY update a struct column if (depth >= column_path.size()) { throw InternalException("Attempting to directly update a struct column - this should not be possible"); @@ -225,20 +233,21 @@ void StructColumnData::UpdateColumn(TransactionData transaction, const vectorUpdateColumn(transaction, data_table, column_path, update_vector, row_ids, update_count, depth + 1, + row_group_start); } else { if (update_column > sub_columns.size()) { throw InternalException("Update column_path out of range"); } - sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, - depth + 1); + sub_columns[update_column - 1]->UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, + update_count, depth + 1, row_group_start); } } unique_ptr StructColumnData::GetUpdateStatistics() { // check if any child column has updates auto stats = BaseStatistics::CreateEmpty(type); - auto validity_stats = validity.GetUpdateStatistics(); + auto validity_stats = validity->GetUpdateStatistics(); if (validity_stats) { stats.Merge(*validity_stats); } @@ -261,22 +270,38 @@ void 
StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &s state.child_states.push_back(std::move(child_state)); } // fetch the validity state - validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); // fetch the sub-column states for (idx_t i = 0; i < child_entries.size(); i++) { sub_columns[i]->FetchRow(transaction, *state.child_states[i + 1], row_id, *child_entries[i], result_idx); } } -void StructColumnData::CommitDropColumn() { - validity.CommitDropColumn(); +void StructColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + validity->VisitBlockIds(visitor); for (auto &sub_column : sub_columns) { - sub_column->CommitDropColumn(); + sub_column->VisitBlockIds(visitor); + } +} + +void StructColumnData::SetValidityData(shared_ptr validity_p) { + if (validity) { + throw InternalException("StructColumnData::SetValidityData cannot be used to overwrite existing validity"); + } + validity_p->SetParent(this); + this->validity = std::move(validity_p); +} + +void StructColumnData::SetChildData(idx_t i, shared_ptr child_column_p) { + if (sub_columns[i]) { + throw InternalException("StructColumnData::SetChildData cannot be used to overwrite existing data"); } + child_column_p->SetParent(this); + this->sub_columns[i] = std::move(child_column_p); } struct StructColumnCheckpointState : public ColumnCheckpointState { - StructColumnCheckpointState(RowGroup &row_group, ColumnData &column_data, + StructColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) : ColumnCheckpointState(row_group, column_data, partial_block_manager) { global_stats = StructStats::CreateEmpty(column_data.type).ToUnique(); @@ -286,8 +311,27 @@ struct StructColumnCheckpointState : public ColumnCheckpointState { vector> child_states; public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, original_column.type, + ColumnDataType::CHECKPOINT_TARGET, nullptr); + } + + shared_ptr GetFinalResult() override { + if (!result_column) { + result_column = CreateEmptyColumnData(); + } + auto &column_data = result_column->Cast(); + auto validity_child = validity_state->GetFinalResult(); + column_data.SetValidityData(shared_ptr_cast(std::move(validity_child))); + for (idx_t i = 0; i < child_states.size(); i++) { + column_data.SetChildData(i, child_states[i]->GetFinalResult()); + } + return ColumnCheckpointState::GetFinalResult(); + } unique_ptr GetStatistics() override { D_ASSERT(global_stats); + global_stats->Merge(*validity_state->GetStatistics()); for (idx_t i = 0; i < child_states.size(); i++) { StructStats::SetChildStats(*global_stats, i, child_states[i]->GetStatistics()); } @@ -295,7 +339,7 @@ struct StructColumnCheckpointState : public ColumnCheckpointState { } PersistentColumnData ToPersistentData() override { - PersistentColumnData data(PhysicalType::STRUCT); + PersistentColumnData data(original_column.type); data.child_columns.push_back(validity_state->ToPersistentData()); for (auto &child_state : child_states) { data.child_columns.push_back(child_state->ToPersistentData()); @@ -304,15 +348,16 @@ struct StructColumnCheckpointState : public ColumnCheckpointState { } }; -unique_ptr StructColumnData::CreateCheckpointState(RowGroup &row_group, +unique_ptr StructColumnData::CreateCheckpointState(const RowGroup &row_group, 
PartialBlockManager &partial_block_manager) { return make_uniq(row_group, *this, partial_block_manager); } -unique_ptr StructColumnData::Checkpoint(RowGroup &row_group, +unique_ptr StructColumnData::Checkpoint(const RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); - checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); + checkpoint_state->validity_state = validity->Checkpoint(row_group, checkpoint_info); for (auto &sub_column : sub_columns) { checkpoint_state->child_states.push_back(sub_column->Checkpoint(row_group, checkpoint_info)); } @@ -320,7 +365,7 @@ unique_ptr StructColumnData::Checkpoint(RowGroup &row_gro } bool StructColumnData::IsPersistent() { - if (!validity.IsPersistent()) { + if (!validity->IsPersistent()) { return false; } for (auto &child_col : sub_columns) { @@ -332,7 +377,7 @@ bool StructColumnData::IsPersistent() { } bool StructColumnData::HasAnyChanges() const { - if (validity.HasAnyChanges()) { + if (validity->HasAnyChanges()) { return true; } for (auto &child_col : sub_columns) { @@ -344,8 +389,8 @@ bool StructColumnData::HasAnyChanges() const { } PersistentColumnData StructColumnData::Serialize() { - PersistentColumnData persistent_data(PhysicalType::STRUCT); - persistent_data.child_columns.push_back(validity.Serialize()); + PersistentColumnData persistent_data(type); + persistent_data.child_columns.push_back(validity->Serialize()); for (auto &sub_column : sub_columns) { persistent_data.child_columns.push_back(sub_column->Serialize()); } @@ -353,28 +398,28 @@ PersistentColumnData StructColumnData::Serialize() { } void StructColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) { - validity.InitializeColumn(column_data.child_columns[0], target_stats); + validity->InitializeColumn(column_data.child_columns[0], target_stats); for (idx_t c_idx = 0; c_idx < sub_columns.size(); c_idx++) { auto &child_stats = StructStats::GetChildStats(target_stats, c_idx); sub_columns[c_idx]->InitializeColumn(column_data.child_columns[c_idx + 1], child_stats); } - this->count = validity.count.load(); + this->count = validity->count.load(); } -void StructColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { +void StructColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); for (idx_t i = 0; i < sub_columns.size(); i++) { col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(row_group_index, col_path, result); + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } void StructColumnData::Verify(RowGroup &parent) { #ifdef DEBUG ColumnData::Verify(parent); - validity.Verify(parent); + validity->Verify(parent); for (auto &sub_column : sub_columns) { sub_column->Verify(parent); } diff --git a/src/duckdb/src/storage/table/table_statistics.cpp b/src/duckdb/src/storage/table/table_statistics.cpp index 61d85319c..edcd4d04e 100644 --- a/src/duckdb/src/storage/table/table_statistics.cpp +++ b/src/duckdb/src/storage/table/table_statistics.cpp @@ -23,6 +23,33 @@ void 
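The struct checkpoint state builds a fresh column object (CreateEmptyColumnData) and wires the finalized validity and child columns into it exactly once, rather than checkpointing in place. A simplified sketch of that build-and-wire pattern, with stand-in types instead of DuckDB's ColumnData classes:

#include <memory>
#include <stdexcept>
#include <vector>

// Stand-in types, not DuckDB's ColumnData/ColumnCheckpointState.
struct Node {
	std::vector<std::shared_ptr<Node>> children;

	void SetChild(size_t i, std::shared_ptr<Node> child) {
		if (children[i]) {
			throw std::runtime_error("child already set"); // guard against double wiring
		}
		children[i] = std::move(child);
	}
};

struct NodeCheckpointState {
	std::vector<NodeCheckpointState> child_states;
	std::shared_ptr<Node> result;

	std::shared_ptr<Node> GetFinalResult() {
		if (!result) {
			result = std::make_shared<Node>();
			result->children.resize(child_states.size()); // start empty, like a checkpoint target
		}
		for (size_t i = 0; i < child_states.size(); i++) {
			result->SetChild(i, child_states[i].GetFinalResult());
		}
		return result;
	}
};

int main() {
	NodeCheckpointState root;
	root.child_states.resize(2);
	auto finalized = root.GetFinalResult(); // fresh tree, children attached exactly once
	return finalized->children.size() == 2 ? 0 : 1;
}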
TableStatistics::Initialize(const vector &types, PersistentTab } // LCOV_EXCL_STOP } +void TableStatistics::InitializeEmpty(const TableStatistics &other) { + D_ASSERT(Empty()); + D_ASSERT(!table_sample); + + stats_lock = make_shared_ptr(); + if (other.table_sample) { + D_ASSERT(other.table_sample->type == SampleType::RESERVOIR_SAMPLE); + auto &res = other.table_sample->Cast(); + table_sample = res.Copy(); + } else { + table_sample = make_uniq(static_cast(FIXED_SAMPLE_SIZE)); + } + + for (auto &stats : other.column_stats) { + auto new_column_stats = ColumnStatistics::CreateEmptyStats(stats->Statistics().GetType()); + if (stats->HasDistinctStats()) { + new_column_stats->SetDistinct(stats->DistinctStats().Copy()); + } + + auto &base_stats = new_column_stats->Statistics(); + if (new_column_stats->HasDistinctStats()) { + base_stats.SetDistinctCount(new_column_stats->DistinctStats().GetCount()); + } + column_stats.push_back(new_column_stats); + } +} + void TableStatistics::InitializeEmpty(const vector &types) { D_ASSERT(Empty()); D_ASSERT(!table_sample); @@ -161,6 +188,12 @@ void TableStatistics::DestroyTableSample(TableStatisticsLock &lock) const { } } +void TableStatistics::SetStats(TableStatistics &other) { + TableStatisticsLock lock(*stats_lock); + column_stats = std::move(other.column_stats); + table_sample = std::move(other.table_sample); +} + unique_ptr TableStatistics::CopyStats(idx_t i) { lock_guard l(*stats_lock); auto result = column_stats[i]->Statistics().Copy(); diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp index 8056907bc..6c6338f41 100644 --- a/src/duckdb/src/storage/table/update_segment.cpp +++ b/src/duckdb/src/storage/table/update_segment.cpp @@ -7,6 +7,7 @@ #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/update_info.hpp" #include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/storage/data_table.hpp" #include @@ -104,9 +105,12 @@ idx_t UpdateInfo::GetAllocSize(idx_t type_size) { return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); } -void UpdateInfo::Initialize(UpdateInfo &info, transaction_t transaction_id) { +void UpdateInfo::Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id, + idx_t row_group_start) { info.max = STANDARD_VECTOR_SIZE; + info.row_group_start = row_group_start; info.version_number = transaction_id; + info.table = &data_table; info.segment = nullptr; info.prev.entry = nullptr; info.next.entry = nullptr; @@ -382,6 +386,8 @@ void UpdateSegment::FetchCommittedRange(idx_t start_row, idx_t count, Vector &re if (!root) { return; } + D_ASSERT(start_row <= column_data.count); + D_ASSERT(start_row + count <= column_data.count); D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); idx_t end_row = start_row + count; @@ -485,13 +491,16 @@ static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type } void UpdateSegment::FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx) { - idx_t vector_index = (row_id - column_data.start) / STANDARD_VECTOR_SIZE; + if (row_id > column_data.count) { + throw InternalException("UpdateSegment::FetchRow out of range"); + } + idx_t vector_index = row_id / STANDARD_VECTOR_SIZE; auto lock_handle = lock.GetSharedLock(); auto entry = GetUpdateNode(*lock_handle, vector_index); if (!entry.IsSet()) { return; } - idx_t row_in_vector = (row_id - column_data.start) - vector_index * STANDARD_VECTOR_SIZE; + idx_t 
row_in_vector = row_id - vector_index * STANDARD_VECTOR_SIZE; auto pin = entry.Pin(); fetch_row_function(transaction.start_time, transaction.transaction_id, UpdateInfo::Get(pin), row_in_vector, result, result_idx); @@ -814,8 +823,8 @@ struct ExtractValidityEntry { template static void MergeUpdateLoopInternal(UpdateInfo &base_info, V *base_table_data, UpdateInfo &update_info, const SelectionVector &update_vector_sel, const V *update_vector_data, row_t *ids, - idx_t count, const SelectionVector &sel) { - auto base_id = base_info.segment->column_data.start + base_info.vector_index * STANDARD_VECTOR_SIZE; + idx_t count, const SelectionVector &sel, idx_t row_group_start) { + auto base_id = row_group_start + base_info.vector_index * STANDARD_VECTOR_SIZE; #ifdef DEBUG // all of these should be sorted, otherwise the below algorithm does not work for (idx_t i = 1; i < count; i++) { @@ -918,20 +927,22 @@ static void MergeUpdateLoopInternal(UpdateInfo &base_info, V *base_table_data, U } static void MergeValidityLoop(UpdateInfo &base_info, Vector &base_data, UpdateInfo &update_info, - UnifiedVectorFormat &update, row_t *ids, idx_t count, const SelectionVector &sel) { + UnifiedVectorFormat &update, row_t *ids, idx_t count, const SelectionVector &sel, + idx_t row_group_start) { auto &base_validity = FlatVector::Validity(base_data); auto &update_validity = update.validity; - MergeUpdateLoopInternal(base_info, &base_validity, update_info, - *update.sel, &update_validity, ids, count, sel); + MergeUpdateLoopInternal( + base_info, &base_validity, update_info, *update.sel, &update_validity, ids, count, sel, row_group_start); } template static void MergeUpdateLoop(UpdateInfo &base_info, Vector &base_data, UpdateInfo &update_info, - UnifiedVectorFormat &update, row_t *ids, idx_t count, const SelectionVector &sel) { + UnifiedVectorFormat &update, row_t *ids, idx_t count, const SelectionVector &sel, + idx_t row_group_start) { auto base_table_data = FlatVector::GetData(base_data); auto update_vector_data = update.GetData(update); MergeUpdateLoopInternal(base_info, base_table_data, update_info, *update.sel, update_vector_data, ids, count, - sel); + sel, row_group_start); } static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type) { @@ -1004,6 +1015,7 @@ idx_t TemplatedUpdateNumericStatistics(UpdateSegment *segment, SegmentStatistics auto &mask = update.validity; if (mask.AllValid()) { + stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); stats.statistics.UpdateNumericStats(update_data[idx]); @@ -1016,8 +1028,11 @@ idx_t TemplatedUpdateNumericStatistics(UpdateSegment *segment, SegmentStatistics for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); if (mask.RowIsValid(idx)) { + stats.statistics.SetHasNoNullFast(); sel.set_index(not_null_count++, i); stats.statistics.UpdateNumericStats(update_data[idx]); + } else { + stats.statistics.SetHasNullFast(); } } return not_null_count; @@ -1029,6 +1044,7 @@ idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, U auto update_data = update.GetDataNoConst(update); auto &mask = update.validity; if (mask.AllValid()) { + stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); auto &str = update_data[idx]; @@ -1045,12 +1061,15 @@ idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, U for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); if 
(mask.RowIsValid(idx)) { + stats.statistics.SetHasNoNullFast(); sel.set_index(not_null_count++, i); auto &str = update_data[idx]; StringStats::Update(stats.statistics, str); if (!str.IsInlined()) { update_data[idx] = segment->GetStringHeap().AddBlob(str); } + } else { + stats.statistics.SetHasNullFast(); } } if (not_null_count == count) { @@ -1236,11 +1255,11 @@ static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) return pos; } -UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, - unsafe_unique_array &data) { +UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, DataTable &data_table, idx_t type_size, idx_t count, + unsafe_unique_array &data, idx_t row_group_start) { data = make_unsafe_uniq_array_uninitialized(UpdateInfo::GetAllocSize(type_size)); auto update_info = reinterpret_cast(data.get()); - UpdateInfo::Initialize(*update_info, transaction.transaction_id); + UpdateInfo::Initialize(*update_info, data_table, transaction.transaction_id, row_group_start); return update_info; } @@ -1258,8 +1277,8 @@ void UpdateSegment::InitializeUpdateInfo(idx_t vector_idx) { } } -void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update_p, row_t *ids, idx_t count, - Vector &base_data) { +void UpdateSegment::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_p, + row_t *ids, idx_t count, Vector &base_data, idx_t row_group_start) { // obtain an exclusive lock auto write_lock = lock.GetExclusiveLock(); @@ -1286,8 +1305,8 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // get the vector index based on the first id // we assert that all updates must be part of the same vector auto first_id = ids[sel.get_index(0)]; - idx_t vector_index = (UnsafeNumericCast(first_id) - column_data.start) / STANDARD_VECTOR_SIZE; - idx_t vector_offset = column_data.start + vector_index * STANDARD_VECTOR_SIZE; + idx_t vector_index = (UnsafeNumericCast(first_id) - row_group_start) / STANDARD_VECTOR_SIZE; + idx_t vector_offset = row_group_start + vector_index * STANDARD_VECTOR_SIZE; if (!root || vector_index >= root->info.size() || !root->info[vector_index].IsSet()) { // get a list of effective updates - i.e. 
updates that actually change rows @@ -1302,7 +1321,7 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect InitializeUpdateInfo(vector_index); - D_ASSERT(idx_t(first_id) >= column_data.start); + D_ASSERT(idx_t(first_id) >= row_group_start); if (root->info[vector_index].IsSet()) { // there is already a version here, check if there are any conflicts and search for the node that belongs to @@ -1322,10 +1341,11 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // no updates made yet by this transaction: initially the update info to empty if (transaction.transaction) { auto &dtransaction = transaction.transaction->Cast(); - node_ref = dtransaction.CreateUpdateInfo(type_size, count); + node_ref = dtransaction.CreateUpdateInfo(type_size, data_table, count, row_group_start); node = &UpdateInfo::Get(node_ref); } else { - node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + node = + CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data, row_group_start); } node->segment = this; node->vector_index = vector_index; @@ -1349,18 +1369,17 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect node->Verify(); // now we are going to perform the merge - merge_update_function(base_info, base_data, *node, update_format, ids, count, sel); + merge_update_function(base_info, base_data, *node, update_format, ids, count, sel, row_group_start); base_info.Verify(); node->Verify(); } else { - // there is no version info yet: create the top level update info and fill it with the updates // allocate space for the UpdateInfo in the allocator idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto handle = root->allocator.Allocate(alloc_size); auto &update_info = UpdateInfo::Get(handle); - UpdateInfo::Initialize(update_info, TRANSACTION_ID_START - 1); + UpdateInfo::Initialize(update_info, data_table, TRANSACTION_ID_START - 1, row_group_start); update_info.column_index = column_index; InitializeUpdateInfo(update_info, ids, sel, count, vector_index, vector_offset); @@ -1370,10 +1389,11 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect UndoBufferReference node_ref; optional_ptr transaction_node; if (transaction.transaction) { - node_ref = transaction.transaction->CreateUpdateInfo(type_size, count); + node_ref = transaction.transaction->CreateUpdateInfo(type_size, data_table, count, row_group_start); transaction_node = &UpdateInfo::Get(node_ref); } else { - transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + transaction_node = + CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data, row_group_start); } InitializeUpdateInfo(*transaction_node, ids, sel, count, vector_index, vector_offset); diff --git a/src/duckdb/src/storage/table/validity_column_data.cpp b/src/duckdb/src/storage/table/validity_column_data.cpp index fc8a9e1ea..81e6fda4d 100644 --- a/src/duckdb/src/storage/table/validity_column_data.cpp +++ b/src/duckdb/src/storage/table/validity_column_data.cpp @@ -1,21 +1,71 @@ #include "duckdb/storage/table/validity_column_data.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table/standard_column_data.hpp" namespace duckdb { ValidityColumnData::ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, - 
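With the per-column start offset removed, update row ids are translated against the row group start that is now threaded through explicitly. A small self-contained sketch of that addressing, assuming the usual STANDARD_VECTOR_SIZE of 2048; the helper name is illustrative only:

#include <cstdint>
#include <iostream>

static constexpr uint64_t kVectorSize = 2048; // assumed STANDARD_VECTOR_SIZE

struct VectorAddress {
	uint64_t vector_index;  // which vector inside the row group
	uint64_t row_in_vector; // offset inside that vector
};

// Translate a global row id against the row group start, as the update path
// above does with (row_id - row_group_start) / STANDARD_VECTOR_SIZE.
static VectorAddress Translate(uint64_t row_id, uint64_t row_group_start) {
	uint64_t offset = row_id - row_group_start;
	return {offset / kVectorSize, offset % kVectorSize};
}

int main() {
	// Row group starting at global row 10240; update touches global row 12300.
	auto addr = Translate(12300, 10240);
	std::cout << "vector_index=" << addr.vector_index      // 1
	          << " row_in_vector=" << addr.row_in_vector   // 12
	          << "\n";
	return 0;
}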
idx_t start_row, ColumnData &parent) - : ColumnData(block_manager, info, column_index, start_row, LogicalType(LogicalTypeId::VALIDITY), &parent) { + ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, LogicalType(LogicalTypeId::VALIDITY), data_type, parent) { +} + +ValidityColumnData::ValidityColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + ColumnData &parent) + : ValidityColumnData(block_manager, info, column_index, parent.GetDataType(), parent) { } FilterPropagateResult ValidityColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } +void ValidityColumnData::UpdateWithBase(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, ColumnData &base, + idx_t row_group_start) { + Vector base_vector(base.type); + ColumnScanState validity_scan_state(nullptr); + FetchUpdateData(validity_scan_state, row_ids, base_vector, row_group_start); + if (validity_scan_state.current.get()->ReferenceNode().get()->GetCompressionFunction().type == + CompressionType::COMPRESSION_EMPTY) { + // The validity is actually covered by the data, so we read it to get the validity for UpdateInternal. + ColumnScanState data_scan_state(nullptr); + auto fetch_count = base.Fetch(data_scan_state, row_ids[0], base_vector); + base_vector.Flatten(fetch_count); + } + + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector, + row_group_start); +} + void ValidityColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { lock_guard l(stats_lock); ColumnData::AppendData(stats, state, vdata, count); } + +struct ValidityColumnCheckpointState : public ColumnCheckpointState { + ValidityColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) + : ColumnCheckpointState(row_group, column_data, partial_block_manager) { + } + +public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, ColumnDataType::CHECKPOINT_TARGET, + nullptr); + } +}; + +unique_ptr +ValidityColumnData::CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +void ValidityColumnData::Verify(RowGroup &parent) { + D_ASSERT(HasParent()); + ColumnData::Verify(parent); +} + } // namespace duckdb diff --git a/src/duckdb/src/storage/table/variant/variant_shredding.cpp b/src/duckdb/src/storage/table/variant/variant_shredding.cpp new file mode 100644 index 000000000..cd739926c --- /dev/null +++ b/src/duckdb/src/storage/table/variant/variant_shredding.cpp @@ -0,0 +1,728 @@ +#include "duckdb/storage/table/variant_column_data.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/types/variant_visitor.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" +#include "duckdb/function/variant/variant_normalize.hpp" +#include "duckdb/common/serializer/varint.hpp" +#ifdef DEBUG +#include "duckdb/common/value_operations/value_operations.hpp" +#endif + +namespace duckdb { + +namespace { + +struct VariantStatsVisitor { + using result_type = void; + + static void VisitNull(VariantShreddingStats &stats, idx_t stats_column_index) { + return; + } + static void 
VisitBoolean(bool val, VariantShreddingStats &stats, idx_t stats_column_index) { + return; + } + + static void VisitMetadata(VariantLogicalType type_id, VariantShreddingStats &stats, idx_t stats_column_index) { + auto &column_stats = stats.GetColumnStats(stats_column_index); + column_stats.SetType(type_id); + } + + template + static void VisitInteger(T val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitFloat(float val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitDouble(double val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitUUID(hugeint_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitDate(date_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitInterval(interval_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTime(dtime_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimeNanos(dtime_ns_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimeTZ(dtime_tz_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimestampSec(timestamp_sec_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimestampMs(timestamp_ms_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimestamp(timestamp_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimestampNanos(timestamp_ns_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitTimestampTZ(timestamp_tz_t val, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void WriteStringInternal(const string_t &str, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitString(const string_t &str, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitBlob(const string_t &blob, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitBignum(const string_t &bignum, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitGeometry(const string_t &geom, VariantShreddingStats &stats, idx_t stats_column_index) { + } + static void VisitBitstring(const string_t &bits, VariantShreddingStats &stats, idx_t stats_column_index) { + } + + template + static void VisitDecimal(T val, uint32_t width, uint32_t scale, VariantShreddingStats &stats, + idx_t stats_column_index) { + auto &column_stats = stats.GetColumnStats(stats_column_index); + + auto decimal_count = column_stats.type_counts[static_cast(VariantLogicalType::DECIMAL)]; + D_ASSERT(decimal_count); + //! 
Visit is called after VisitMetadata, so even for the first DECIMAL value, count will already be 1 + decimal_count--; + + if (!decimal_count) { + column_stats.decimal_width = width; + column_stats.decimal_scale = scale; + column_stats.decimal_consistent = true; + return; + } + + if (!column_stats.decimal_consistent) { + return; + } + + if (width != column_stats.decimal_width || scale != column_stats.decimal_scale) { + column_stats.decimal_consistent = false; + } + } + + static void VisitArray(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantShreddingStats &stats, idx_t stats_column_index) { + auto &element_stats = stats.GetOrCreateElement(stats_column_index); + auto index = element_stats.index; + VariantVisitor::VisitArrayItems(variant, row, nested_data, stats, index); + } + + static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantShreddingStats &stats, idx_t stats_column_index) { + //! Then visit the fields in sorted order + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto source_children_idx = nested_data.children_idx + i; + + //! Add the key of the field to the result + auto keys_index = variant.GetKeysIndex(row, source_children_idx); + auto &key = variant.GetKey(row, keys_index); + + auto &child_stats = stats.GetOrCreateField(stats_column_index, key.GetString()); + auto index = child_stats.index; + + //! Visit the child value + auto values_index = variant.GetValuesIndex(row, source_children_idx); + VariantVisitor::Visit(variant, row, values_index, stats, index); + } + } + + static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantShreddingStats &stats, + idx_t stats_column_index) { + throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); + } +}; + +static unordered_set GetVariantType(const LogicalType &type) { + if (type.id() == LogicalTypeId::ANY) { + return {}; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: + return {VariantLogicalType::OBJECT}; + case LogicalTypeId::LIST: + return {VariantLogicalType::ARRAY}; + case LogicalTypeId::BOOLEAN: + return {VariantLogicalType::BOOL_TRUE, VariantLogicalType::BOOL_FALSE}; + case LogicalTypeId::TINYINT: + return {VariantLogicalType::INT8}; + case LogicalTypeId::SMALLINT: + return {VariantLogicalType::INT16}; + case LogicalTypeId::INTEGER: + return {VariantLogicalType::INT32}; + case LogicalTypeId::BIGINT: + return {VariantLogicalType::INT64}; + case LogicalTypeId::HUGEINT: + return {VariantLogicalType::INT128}; + case LogicalTypeId::UTINYINT: + return {VariantLogicalType::UINT8}; + case LogicalTypeId::USMALLINT: + return {VariantLogicalType::UINT16}; + case LogicalTypeId::UINTEGER: + return {VariantLogicalType::UINT32}; + case LogicalTypeId::UBIGINT: + return {VariantLogicalType::UINT64}; + case LogicalTypeId::UHUGEINT: + return {VariantLogicalType::UINT128}; + case LogicalTypeId::FLOAT: + return {VariantLogicalType::FLOAT}; + case LogicalTypeId::DOUBLE: + return {VariantLogicalType::DOUBLE}; + case LogicalTypeId::DECIMAL: + return {VariantLogicalType::DECIMAL}; + case LogicalTypeId::DATE: + return {VariantLogicalType::DATE}; + case LogicalTypeId::TIME: + return {VariantLogicalType::TIME_MICROS}; + case LogicalTypeId::TIME_TZ: + return {VariantLogicalType::TIME_MICROS_TZ}; + case LogicalTypeId::TIMESTAMP_TZ: + return {VariantLogicalType::TIMESTAMP_MICROS_TZ}; + case LogicalTypeId::TIMESTAMP: + return {VariantLogicalType::TIMESTAMP_MICROS}; + case 
LogicalTypeId::TIMESTAMP_SEC: + return {VariantLogicalType::TIMESTAMP_SEC}; + case LogicalTypeId::TIMESTAMP_MS: + return {VariantLogicalType::TIMESTAMP_MILIS}; + case LogicalTypeId::TIMESTAMP_NS: + return {VariantLogicalType::TIMESTAMP_NANOS}; + case LogicalTypeId::BLOB: + return {VariantLogicalType::BLOB}; + case LogicalTypeId::VARCHAR: + return {VariantLogicalType::VARCHAR}; + case LogicalTypeId::UUID: + return {VariantLogicalType::UUID}; + case LogicalTypeId::BIGNUM: + return {VariantLogicalType::BIGNUM}; + case LogicalTypeId::TIME_NS: + return {VariantLogicalType::TIME_NANOS}; + case LogicalTypeId::INTERVAL: + return {VariantLogicalType::INTERVAL}; + case LogicalTypeId::BIT: + return {VariantLogicalType::BITSTRING}; + case LogicalTypeId::GEOMETRY: + return {VariantLogicalType::GEOMETRY}; + default: + throw BinderException("Type '%s' can't be translated to a VARIANT type", type.ToString()); + } +} + +struct DuckDBVariantShreddingState : public VariantShreddingState { +public: + DuckDBVariantShreddingState(const LogicalType &type, idx_t total_count) + : VariantShreddingState(type, total_count), variant_types(GetVariantType(type)) { + } + ~DuckDBVariantShreddingState() override { + } + +public: + const unordered_set &GetVariantTypes() override { + return variant_types; + } + +private: + unordered_set variant_types; +}; + +struct UnshreddedValue { +public: + UnshreddedValue(uint32_t value_index, uint32_t &target_value_index, vector &&children = {}) + : source_value_index(value_index), target_value_index(target_value_index), + unshredded_children(std::move(children)) { + } + +public: + uint32_t source_value_index; + uint32_t &target_value_index; + vector unshredded_children; +}; + +struct DuckDBVariantShredding : public VariantShredding { +public: + explicit DuckDBVariantShredding(idx_t count) : VariantShredding(), unshredded_values(count) { + } + ~DuckDBVariantShredding() override = default; + +public: + void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) override; + void AnalyzeVariantValues(UnifiedVariantVectorData &variant, Vector &value, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, + DuckDBVariantShreddingState &shredding_state, idx_t count); + +public: + //! 
For each row of the variant, the value_index(es) of the values to write to the 'unshredded' Vector + vector> unshredded_values; +}; + +} // namespace + +void VariantColumnStatsData::SetType(VariantLogicalType type) { + type_counts[static_cast(type)]++; + total_count++; +} + +VariantColumnStatsData &VariantShreddingStats::GetOrCreateElement(idx_t parent_index) { + auto &parent_column = GetColumnStats(parent_index); + + idx_t element_stats = parent_column.element_stats; + if (parent_column.element_stats == DConstants::INVALID_INDEX) { + parent_column.element_stats = columns.size(); + element_stats = parent_column.element_stats; + columns.emplace_back(element_stats); + } + return GetColumnStats(element_stats); +} + +VariantColumnStatsData &VariantShreddingStats::GetOrCreateField(idx_t parent_index, const string &name) { + auto &parent_column = columns[parent_index]; + auto it = parent_column.field_stats.find(name); + + idx_t field_stats; + if (it == parent_column.field_stats.end()) { + it = parent_column.field_stats.emplace(name, columns.size()).first; + field_stats = it->second; + columns.emplace_back(field_stats); + } else { + field_stats = it->second; + } + return GetColumnStats(field_stats); +} + +VariantColumnStatsData &VariantShreddingStats::GetColumnStats(idx_t index) { + D_ASSERT(columns.size() > index); + return columns[index]; +} + +const VariantColumnStatsData &VariantShreddingStats::GetColumnStats(idx_t index) const { + D_ASSERT(columns.size() > index); + return columns[index]; +} + +static LogicalType ProduceShreddedType(VariantLogicalType type_id) { + switch (type_id) { + case VariantLogicalType::BOOL_TRUE: + case VariantLogicalType::BOOL_FALSE: + return LogicalTypeId::BOOLEAN; + case VariantLogicalType::INT8: + return LogicalTypeId::TINYINT; + case VariantLogicalType::INT16: + return LogicalTypeId::SMALLINT; + case VariantLogicalType::INT32: + return LogicalTypeId::INTEGER; + case VariantLogicalType::INT64: + return LogicalTypeId::BIGINT; + case VariantLogicalType::INT128: + return LogicalTypeId::HUGEINT; + case VariantLogicalType::UINT8: + return LogicalTypeId::UTINYINT; + case VariantLogicalType::UINT16: + return LogicalTypeId::USMALLINT; + case VariantLogicalType::UINT32: + return LogicalTypeId::UINTEGER; + case VariantLogicalType::UINT64: + return LogicalTypeId::UBIGINT; + case VariantLogicalType::UINT128: + return LogicalTypeId::UHUGEINT; + case VariantLogicalType::FLOAT: + return LogicalTypeId::FLOAT; + case VariantLogicalType::DOUBLE: + return LogicalTypeId::DOUBLE; + case VariantLogicalType::DECIMAL: + throw InternalException("Can't shred on DECIMAL"); + case VariantLogicalType::VARCHAR: + return LogicalTypeId::VARCHAR; + case VariantLogicalType::BLOB: + return LogicalTypeId::BLOB; + case VariantLogicalType::UUID: + return LogicalTypeId::UUID; + case VariantLogicalType::DATE: + return LogicalTypeId::DATE; + case VariantLogicalType::TIME_MICROS: + return LogicalTypeId::TIME; + case VariantLogicalType::TIME_NANOS: + return LogicalTypeId::TIME_NS; + case VariantLogicalType::TIMESTAMP_SEC: + return LogicalTypeId::TIMESTAMP_SEC; + case VariantLogicalType::TIMESTAMP_MILIS: + return LogicalTypeId::TIMESTAMP_MS; + case VariantLogicalType::TIMESTAMP_MICROS: + return LogicalTypeId::TIMESTAMP; + case VariantLogicalType::TIMESTAMP_NANOS: + return LogicalTypeId::TIMESTAMP_NS; + case VariantLogicalType::TIME_MICROS_TZ: + return LogicalTypeId::TIME_TZ; + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + return LogicalTypeId::TIMESTAMP_TZ; + case VariantLogicalType::INTERVAL: + return 
LogicalTypeId::INTERVAL; + case VariantLogicalType::BIGNUM: + return LogicalTypeId::BIGNUM; + case VariantLogicalType::BITSTRING: + return LogicalTypeId::BIT; + case VariantLogicalType::GEOMETRY: + return LogicalTypeId::GEOMETRY; + case VariantLogicalType::OBJECT: + case VariantLogicalType::ARRAY: + throw InternalException("Already handled above"); + default: + throw NotImplementedException("Shredding on VariantLogicalType::%s not supported yet", + EnumUtil::ToString(type_id)); + } +} + +static LogicalType SetShreddedType(const LogicalType &typed_value) { + child_list_t child_types; + child_types.emplace_back("untyped_value_index", LogicalType::UINTEGER); + child_types.emplace_back("typed_value", typed_value); + return LogicalType::STRUCT(child_types); +} + +bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData &column, LogicalType &out_type) const { + idx_t max_count = 0; + uint8_t type_index; + if (column.type_counts[0] == column.total_count) { + //! All NULL, emit INT32 + out_type = SetShreddedType(LogicalTypeId::INTEGER); + return true; + } + + //! Skip the 'VARIANT_NULL' type, we can't shred on NULL + for (uint8_t i = 1; i < static_cast(VariantLogicalType::ENUM_SIZE); i++) { + if (i == static_cast(VariantLogicalType::DECIMAL) && !column.decimal_consistent) { + //! Can't shred on DECIMAL, not consistent + continue; + } + idx_t count = column.type_counts[i]; + if (!max_count || count > max_count) { + max_count = count; + type_index = i; + } + } + + if (!max_count) { + return false; + } + + if (type_index == static_cast(VariantLogicalType::OBJECT)) { + child_list_t child_types; + for (auto &entry : column.field_stats) { + auto &child_column = GetColumnStats(entry.second); + LogicalType child_type; + if (GetShreddedTypeInternal(child_column, child_type)) { + child_types.emplace_back(entry.first, child_type); + } + } + if (child_types.empty()) { + return false; + } + auto shredded_type = LogicalType::STRUCT(child_types); + out_type = SetShreddedType(shredded_type); + return true; + } + if (type_index == static_cast(VariantLogicalType::ARRAY)) { + D_ASSERT(column.element_stats != DConstants::INVALID_INDEX); + auto &element_column = GetColumnStats(column.element_stats); + LogicalType element_type; + if (!GetShreddedTypeInternal(element_column, element_type)) { + return false; + } + auto shredded_type = LogicalType::LIST(element_type); + out_type = SetShreddedType(shredded_type); + return true; + } + if (type_index == static_cast(VariantLogicalType::DECIMAL)) { + auto shredded_type = LogicalType::DECIMAL(static_cast(column.decimal_width), + static_cast(column.decimal_scale)); + out_type = SetShreddedType(shredded_type); + return true; + } + auto type_id = static_cast(type_index); + + auto shredded_type = ProduceShreddedType(type_id); + out_type = SetShreddedType(shredded_type); + return true; +} + +LogicalType VariantShreddingStats::GetShreddedType() const { + auto &root_column = GetColumnStats(0); + + child_list_t child_types; + child_types.emplace_back("unshredded", VariantShredding::GetUnshreddedType()); + LogicalType shredded_type; + if (GetShreddedTypeInternal(root_column, shredded_type)) { + child_types.emplace_back("shredded", shredded_type); + } + return LogicalType::STRUCT(child_types); +} + +void VariantShreddingStats::Update(Vector &input, idx_t count) { + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(input, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + for (idx_t i = 0; i < count; 
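GetShreddedTypeInternal picks the shredded type by keeping per-path counts of each VariantLogicalType and shredding on the most frequent one (skipping NULL and inconsistent DECIMALs). A toy, self-contained version of that selection with made-up counts:

#include <array>
#include <cstddef>
#include <iostream>

// Illustrative enum and histogram only; the real code walks the full
// VariantLogicalType range and excludes VARIANT_NULL and unstable DECIMALs.
enum SimpleType : size_t { NULL_TYPE = 0, INT32, DOUBLE, VARCHAR, TYPE_COUNT };

static const char *kNames[TYPE_COUNT] = {"NULL", "INT32", "DOUBLE", "VARCHAR"};

int main() {
	// Per-column type histogram gathered while scanning the variant values,
	// e.g. a column that is mostly integers with a few strings and NULLs.
	std::array<size_t, TYPE_COUNT> counts = {3 /*NULL*/, 90 /*INT32*/, 0 /*DOUBLE*/, 7 /*VARCHAR*/};

	size_t best = 0;
	size_t best_count = 0;
	for (size_t i = 1; i < TYPE_COUNT; i++) { // skip NULL: we never shred on it
		if (counts[i] > best_count) {
			best_count = counts[i];
			best = i;
		}
	}
	if (best_count == 0) {
		std::cout << "no shredded type\n"; // everything was NULL or untypable
	} else {
		// The winner becomes typed_value; rows of other types stay in the
		// untyped VARIANT blob and are reached through untyped_value_index.
		std::cout << "shred on " << kNames[best] << " (" << best_count << " of 100 rows)\n";
	}
	return 0;
}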
i++) { + VariantVisitor::Visit(variant, i, 0, *this, static_cast(0)); + } +} + +static void VisitObject(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + VariantNormalizerState &state, const vector &child_indices) { + D_ASSERT(child_indices.size() <= nested_data.child_count); + //! First iterate through all fields to populate the map of key -> field + map sorted_fields; + for (auto &child_idx : child_indices) { + auto keys_index = variant.GetKeysIndex(row, nested_data.children_idx + child_idx); + auto &key = variant.GetKey(row, keys_index); + sorted_fields.emplace(key, child_idx); + } + + state.blob_size += VarintEncode(sorted_fields.size(), state.GetDestination()); + D_ASSERT(!sorted_fields.empty()); + + uint32_t children_idx = state.children_size; + uint32_t keys_idx = state.keys_size; + state.blob_size += VarintEncode(children_idx, state.GetDestination()); + state.children_size += sorted_fields.size(); + state.keys_size += sorted_fields.size(); + + //! Then visit the fields in sorted order + for (auto &entry : sorted_fields) { + auto source_children_idx = nested_data.children_idx + entry.second; + + //! Add the key of the field to the result + auto keys_index = variant.GetKeysIndex(row, source_children_idx); + auto &key = variant.GetKey(row, keys_index); + auto dict_index = state.GetOrCreateIndex(key); + state.keys_selvec.set_index(state.keys_offset + keys_idx, dict_index); + + //! Visit the child value + auto values_index = variant.GetValuesIndex(row, source_children_idx); + state.values_indexes[children_idx] = state.values_size; + state.keys_indexes[children_idx] = keys_idx; + children_idx++; + keys_idx++; + VariantVisitor::Visit(variant, row, values_index, state); + } +} + +static vector UnshreddedObjectChildren(UnifiedVariantVectorData &variant, uint32_t row, uint32_t value_index, + DuckDBVariantShreddingState &shredding_state) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, value_index); + + auto shredded_fields = shredding_state.ObjectFields(); + vector unshredded_children; + unshredded_children.reserve(nested_data.child_count); + for (uint32_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, nested_data.children_idx + i); + auto &key = variant.GetKey(row, keys_index); + if (shredded_fields.count(key)) { + continue; + } + unshredded_children.emplace_back(i); + } + return unshredded_children; +} + +//! ~~Write the unshredded values~~, also receiving the 'untyped_value_index' Vector to populate +//! Marking the rows that are shredded in the shredding state +void DuckDBVariantShredding::AnalyzeVariantValues(UnifiedVariantVectorData &variant, Vector &value, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, + DuckDBVariantShreddingState &shredding_state, idx_t count) { + auto &validity = FlatVector::Validity(value); + auto untyped_data = FlatVector::GetData(value); + + for (uint32_t i = 0; i < static_cast(count); i++) { + uint32_t value_index = 0; + if (value_index_sel) { + value_index = static_cast(value_index_sel->get_index(i)); + } + + uint32_t row = i; + if (sel) { + row = static_cast(sel->get_index(i)); + } + + uint32_t result_index = i; + if (result_sel) { + result_index = static_cast(result_sel->get_index(i)); + } + + if (variant.RowIsValid(row) && shredding_state.ValueIsShredded(variant, row, value_index)) { + shredding_state.SetShredded(row, value_index, result_index); + if (shredding_state.type.id() != LogicalTypeId::STRUCT) { + //! 
Value is shredded, directly write a NULL to the 'value' if the type is not an OBJECT + validity.SetInvalid(result_index); + continue; + } + + //! When the type is OBJECT, all excess fields would still need to be written to the 'value' + auto unshredded_children = UnshreddedObjectChildren(variant, row, value_index, shredding_state); + if (unshredded_children.empty()) { + //! Fully shredded object + validity.SetInvalid(result_index); + } else { + //! Deal with partially shredded objects + unshredded_values[row].emplace_back(value_index, untyped_data[result_index], + std::move(unshredded_children)); + } + continue; + } + + //! Deal with unshredded values + if (!variant.RowIsValid(row) || variant.GetTypeId(row, value_index) == VariantLogicalType::VARIANT_NULL) { + //! 0 is reserved for NULL + untyped_data[result_index] = 0; + } else { + unshredded_values[row].emplace_back(value_index, untyped_data[result_index]); + } + } +} + +//! Receive a 'shredded' result Vector, consisting of the 'untyped_value_index' and the 'typed_value' Vector +void DuckDBVariantShredding::WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) { + auto &result_type = result.GetType(); + D_ASSERT(result_type.id() == LogicalTypeId::STRUCT); + auto &child_types = StructType::GetChildTypes(result_type); + auto &child_vectors = StructVector::GetEntries(result); + D_ASSERT(child_types.size() == child_vectors.size()); + + auto &untyped_value_index = *child_vectors[0]; + auto &typed_value = *child_vectors[1]; + + DuckDBVariantShreddingState shredding_state(typed_value.GetType(), count); + AnalyzeVariantValues(variant, untyped_value_index, sel, value_index_sel, result_sel, shredding_state, count); + + SelectionVector null_values; + if (shredding_state.count) { + WriteTypedValues(variant, typed_value, shredding_state.shredded_sel, shredding_state.values_index_sel, + shredding_state.result_sel, shredding_state.count); + //! Set the rows that aren't shredded to NULL + idx_t sel_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto original_index = result_sel ? result_sel->get_index(i) : i; + if (sel_idx < shredding_state.count && shredding_state.result_sel[sel_idx] == original_index) { + sel_idx++; + continue; + } + FlatVector::SetNull(typed_value, original_index, true); + } + } else { + //! Set all rows of the typed_value to NULL, nothing is shredded on + for (idx_t i = 0; i < count; i++) { + FlatVector::SetNull(typed_value, result_sel ? result_sel->get_index(i) : i, true); + } + } +} + +void VariantColumnData::ShredVariantData(Vector &input, Vector &output, idx_t count) { + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(input, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + auto &child_vectors = StructVector::GetEntries(output); + + //! First traverse the Variant to write the shredded values and collect the 'untyped_value_index'es + DuckDBVariantShredding shredding(count); + shredding.WriteVariantValues(variant, *child_vectors[1], nullptr, nullptr, nullptr, count); + + //! 
Now we can write the unshredded values + auto &unshredded = *child_vectors[0]; + auto original_keys_size = ListVector::GetListSize(VariantVector::GetKeys(input)); + auto original_children_size = ListVector::GetListSize(VariantVector::GetChildren(input)); + auto original_values_size = ListVector::GetListSize(VariantVector::GetValues(input)); + + auto &keys = VariantVector::GetKeys(unshredded); + auto &children = VariantVector::GetChildren(unshredded); + auto &values = VariantVector::GetValues(unshredded); + auto &data = VariantVector::GetData(unshredded); + + ListVector::Reserve(keys, original_keys_size); + ListVector::SetListSize(keys, 0); + ListVector::Reserve(children, original_children_size); + ListVector::SetListSize(children, 0); + ListVector::Reserve(values, original_values_size); + ListVector::SetListSize(values, 0); + + auto &keys_entry = ListVector::GetEntry(keys); + OrderedOwningStringMap dictionary(StringVector::GetStringBuffer(keys_entry).GetStringAllocator()); + SelectionVector keys_selvec; + keys_selvec.Initialize(original_keys_size); + + VariantVectorData variant_data(unshredded); + for (idx_t row = 0; row < count; row++) { + auto &unshredded_values = shredding.unshredded_values[row]; + + if (unshredded_values.empty()) { + FlatVector::SetNull(unshredded, row, true); + continue; + } + + //! Allocate for the new data, use the same size as source + auto &blob_data = variant_data.blob_data[row]; + auto original_data = variant.GetData(row); + blob_data = StringVector::EmptyString(data, original_data.GetSize()); + + auto &keys_list_entry = variant_data.keys_data[row]; + keys_list_entry.offset = ListVector::GetListSize(keys); + + auto &children_list_entry = variant_data.children_data[row]; + children_list_entry.offset = ListVector::GetListSize(children); + + auto &values_list_entry = variant_data.values_data[row]; + values_list_entry.offset = ListVector::GetListSize(values); + + VariantNormalizerState normalizer_state(row, variant_data, dictionary, keys_selvec); + for (idx_t i = 0; i < unshredded_values.size(); i++) { + auto &unshredded_value = unshredded_values[i]; + auto value_index = unshredded_value.source_value_index; + + unshredded_value.target_value_index = normalizer_state.values_size + 1; + if (!unshredded_value.unshredded_children.empty()) { + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::OBJECT); + auto nested_data = VariantUtils::DecodeNestedData(variant, row, value_index); + + normalizer_state.type_ids[normalizer_state.values_size] = + static_cast(VariantLogicalType::OBJECT); + normalizer_state.byte_offsets[normalizer_state.values_size] = normalizer_state.blob_size; + normalizer_state.values_size++; + VisitObject(variant, row, nested_data, normalizer_state, unshredded_value.unshredded_children); + continue; + } + VariantVisitor::Visit(variant, row, value_index, normalizer_state); + } + blob_data.SetSizeAndFinalize(normalizer_state.blob_size, original_data.GetSize()); + keys_list_entry.length = normalizer_state.keys_size; + children_list_entry.length = normalizer_state.children_size; + values_list_entry.length = normalizer_state.values_size; + + ListVector::SetListSize(keys, ListVector::GetListSize(keys) + normalizer_state.keys_size); + ListVector::SetListSize(children, ListVector::GetListSize(children) + normalizer_state.children_size); + ListVector::SetListSize(values, ListVector::GetListSize(values) + normalizer_state.values_size); + } + + VariantUtils::FinalizeVariantKeys(unshredded, dictionary, keys_selvec, ListVector::GetListSize(keys)); + 
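For partially shredded objects, only the fields not covered by the shredded type are kept in the untyped VARIANT blob; the rest live in typed_value. A simplified sketch of that field partitioning using plain standard-library containers:

#include <iostream>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

// Plain std containers stand in for the variant encoding; this only shows the
// split between shredded fields and the leftover children that get
// re-serialized into the untyped blob.
int main() {
	std::map<std::string, std::string> object = {
	    {"id", "42"}, {"name", "duck"}, {"extra", "{\"nested\": true}"}};
	std::unordered_set<std::string> shredded_fields = {"id", "name"};

	std::vector<std::string> unshredded_children;
	for (auto &entry : object) {
		if (shredded_fields.count(entry.first)) {
			continue; // covered by typed_value, no need to keep it in the blob
		}
		unshredded_children.push_back(entry.first);
	}

	if (unshredded_children.empty()) {
		std::cout << "fully shredded object\n";
	} else {
		std::cout << "partially shredded, keeping:";
		for (auto &key : unshredded_children) {
			std::cout << " " << key;
		}
		std::cout << "\n";
	}
	return 0;
}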
keys_entry.Slice(keys_selvec, ListVector::GetListSize(keys)); + + if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { + unshredded.SetVectorType(VectorType::CONSTANT_VECTOR); + } + +#ifdef DEBUG + Vector roundtrip_result(LogicalType::VARIANT(), count); + VariantColumnData::UnshredVariantData(output, roundtrip_result, count); + + for (idx_t i = 0; i < count; i++) { + auto input_val = input.GetValue(i); + auto roundtripped_val = roundtrip_result.GetValue(i); + if (!ValueOperations::NotDistinctFrom(input_val, roundtripped_val)) { + throw InternalException("Shredding roundtrip verification failed for row: %d, expected: %s, actual: %s", i, + input_val.ToString(), roundtripped_val.ToString()); + } + } + +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp new file mode 100644 index 000000000..3b8e02488 --- /dev/null +++ b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp @@ -0,0 +1,231 @@ +#include "duckdb/storage/table/variant_column_data.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/cast/variant/to_variant_fwd.hpp" +#include "duckdb/common/types/variant_value.hpp" +#include "duckdb/common/types/variant_visitor.hpp" +#include "duckdb/function/variant/variant_value_convert.hpp" + +namespace duckdb { + +template +static VariantValue UnshreddedVariantValue(UnifiedVariantVectorData &input, uint32_t row, uint32_t values_index) { + if (!input.RowIsValid(row)) { + return VariantValue(Value(LogicalTypeId::SQLNULL)); + } + + if (values_index == 0) { + //! 0 is reserved to indicate NULL, to better recognize the situation where a Variant is fully shredded, but has + //! NULLs + return VariantValue(Value(LogicalTypeId::SQLNULL)); + } + values_index--; + + auto type_id = input.GetTypeId(row, values_index); + if (!ALLOW_NULL) { + //! 
We don't expect NULLs at the root, those should have the 'values_index' of 0 + D_ASSERT(type_id != VariantLogicalType::VARIANT_NULL); + } + + if (type_id == VariantLogicalType::OBJECT) { + VariantValue res(VariantValueType::OBJECT); + + auto object_data = VariantUtils::DecodeNestedData(input, row, values_index); + for (idx_t i = 0; i < object_data.child_count; i++) { + auto child_values_index = input.GetValuesIndex(row, object_data.children_idx + i); + auto val = UnshreddedVariantValue(input, row, child_values_index + 1); + + auto keys_index = input.GetKeysIndex(row, object_data.children_idx + i); + auto &key = input.GetKey(row, keys_index); + + res.AddChild(key.GetString(), std::move(val)); + } + return res; + } + if (type_id == VariantLogicalType::ARRAY) { + VariantValue res(VariantValueType::ARRAY); + + auto array_data = VariantUtils::DecodeNestedData(input, row, values_index); + for (idx_t i = 0; i < array_data.child_count; i++) { + auto child_values_index = input.GetValuesIndex(row, array_data.children_idx + i); + auto val = UnshreddedVariantValue(input, row, child_values_index + 1); + + res.AddItem(std::move(val)); + } + return res; + } + auto val = VariantVisitor::Visit(input, row, values_index); + return VariantValue(std::move(val)); +} + +static vector Unshred(UnifiedVariantVectorData &variant, Vector &shredded, idx_t count, + optional_ptr row_sel); + +static vector UnshredTypedLeaf(Vector &typed_value, idx_t count) { + vector res(count); + UnifiedVectorFormat vector_format; + typed_value.ToUnifiedFormat(count, vector_format); + auto &typed_value_validity = vector_format.validity; + + for (idx_t i = 0; i < count; i++) { + if (!typed_value_validity.RowIsValid(vector_format.sel->get_index(i))) { + continue; + } + res[i] = VariantValue(typed_value.GetValue(i)); + } + return res; +} + +static vector UnshredTypedObject(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, + optional_ptr row_sel) { + vector res(count); + + auto &child_types = StructType::GetChildTypes(typed_value.GetType()); + auto &child_entries = StructVector::GetEntries(typed_value); + + D_ASSERT(child_types.size() == child_entries.size()); + + //! First unshred all children + vector> child_values(child_entries.size()); + for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { + auto &child_entry = child_entries[child_idx]; + child_values[child_idx] = Unshred(variant, *child_entry, count, row_sel); + } + + //! 
Then compose the OBJECT value by combining all the children + UnifiedVectorFormat vector_format; + typed_value.ToUnifiedFormat(count, vector_format); + auto &typed_value_validity = vector_format.validity; + for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { + auto &child_name = child_types[child_idx].first; + auto &values = child_values[child_idx]; + + for (idx_t i = 0; i < count; i++) { + if (!typed_value_validity.RowIsValid(vector_format.sel->get_index(i))) { + continue; + } + if (values[i].IsMissing()) { + continue; + } + if (res[i].IsMissing()) { + res[i] = VariantValue(VariantValueType::OBJECT); + } + auto &obj_value = res[i]; + obj_value.AddChild(child_name, std::move(values[i])); + } + } + return res; +} + +static vector UnshredTypedArray(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, + optional_ptr row_sel) { + auto child_size = ListVector::GetListSize(typed_value); + auto &child_vector = ListVector::GetEntry(typed_value); + + D_ASSERT(typed_value.GetType().id() == LogicalTypeId::LIST); + auto list_data = FlatVector::GetData(typed_value); + + UnifiedVectorFormat vector_format; + typed_value.ToUnifiedFormat(count, vector_format); + auto &typed_value_validity = vector_format.validity; + + SelectionVector child_sel(child_size); + for (uint32_t i = 0; i < count; i++) { + if (!typed_value_validity.RowIsValid(vector_format.sel->get_index(i))) { + continue; + } + auto row = row_sel ? static_cast(row_sel->get_index(i)) : i; + auto &list_entry = list_data[i]; + for (idx_t j = 0; j < list_entry.length; j++) { + child_sel[list_entry.offset + j] = row; + } + } + auto child_values = Unshred(variant, child_vector, child_size, child_sel); + + vector res(count); + for (idx_t i = 0; i < count; i++) { + if (!typed_value_validity.RowIsValid(vector_format.sel->get_index(i))) { + continue; + } + auto &list_entry = list_data[i]; + + auto &list_val = res[i]; + list_val = VariantValue(VariantValueType::ARRAY); + list_val.array_items.reserve(list_entry.length); + list_val.array_items.insert( + list_val.array_items.end(), + std::make_move_iterator(child_values.begin() + static_cast(list_entry.offset)), + std::make_move_iterator(child_values.begin() + + static_cast(list_entry.offset + list_entry.length))); + } + return res; +} + +static vector UnshredTypedValue(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, + optional_ptr row_sel) { + auto &type = typed_value.GetType(); + if (type.id() == LogicalTypeId::STRUCT) { + return UnshredTypedObject(variant, typed_value, count, row_sel); + } else if (type.id() == LogicalTypeId::LIST) { + return UnshredTypedArray(variant, typed_value, count, row_sel); + } else { + D_ASSERT(!type.IsNested()); + return UnshredTypedLeaf(typed_value, count); + } +} + +static vector Unshred(UnifiedVariantVectorData &variant, Vector &shredded, idx_t count, + optional_ptr row_sel) { + D_ASSERT(shredded.GetType().id() == LogicalTypeId::STRUCT); + auto &child_entries = StructVector::GetEntries(shredded); + D_ASSERT(child_entries.size() == 2); + + auto &untyped_value_index = *child_entries[0]; + auto &typed_value = *child_entries[1]; + + UnifiedVectorFormat untyped_format; + untyped_value_index.ToUnifiedFormat(count, untyped_format); + auto untyped_index_data = untyped_format.GetData(untyped_format); + auto &untyped_index_validity = untyped_format.validity; + + auto res = UnshredTypedValue(variant, typed_value, count, row_sel); + for (uint32_t i = 0; i < count; i++) { + if 
(!untyped_index_validity.RowIsValid(untyped_format.sel->get_index(i))) { + continue; + } + auto value_index = untyped_index_data[untyped_format.sel->get_index(i)]; + auto row = row_sel ? static_cast(row_sel->get_index(i)) : i; + auto unshredded = UnshreddedVariantValue(variant, row, value_index); + + if (res[i].IsMissing()) { + //! Unshredded, has no shredded value + res[i] = std::move(unshredded); + } else if (!unshredded.IsNull()) { + //! Partial shredding, already has a shredded value that this has to be combined into + D_ASSERT(res[i].value_type == VariantValueType::OBJECT); + D_ASSERT(unshredded.value_type == VariantValueType::OBJECT); + auto &object_children = unshredded.object_children; + for (auto &entry : object_children) { + res[i].AddChild(entry.first, std::move(entry.second)); + } + } + } + return res; +} + +void VariantColumnData::UnshredVariantData(Vector &input, Vector &output, idx_t count) { + D_ASSERT(input.GetType().id() == LogicalTypeId::STRUCT); + auto &child_vectors = StructVector::GetEntries(input); + D_ASSERT(child_vectors.size() == 2); + + auto &unshredded = *child_vectors[0]; + auto &shredded = *child_vectors[1]; + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(unshredded, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + auto variant_values = Unshred(variant, shredded, count, nullptr); + VariantValue::ToVARIANT(variant_values, output); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/variant_column_data.cpp b/src/duckdb/src/storage/table/variant_column_data.cpp new file mode 100644 index 000000000..f984465d0 --- /dev/null +++ b/src/duckdb/src/storage/table/variant_column_data.cpp @@ -0,0 +1,615 @@ +#include "duckdb/storage/table/variant_column_data.hpp" +#include "duckdb/storage/table/struct_column_data.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" +#include "duckdb/function/variant/variant_shredding.hpp" + +namespace duckdb { + +VariantColumnData::VariantColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, + LogicalType type_p, ColumnDataType data_type, optional_ptr parent) + : ColumnData(block_manager, info, column_index, std::move(type_p), data_type, parent) { + D_ASSERT(type.InternalType() == PhysicalType::STRUCT); + + if (data_type != ColumnDataType::CHECKPOINT_TARGET) { + validity = make_shared_ptr(block_manager, info, 0, *this); + // the sub column index, starting at 1 (0 is the validity mask) + idx_t sub_column_index = 1; + auto unshredded_type = VariantShredding::GetUnshreddedType(); + sub_columns.push_back( + ColumnData::CreateColumn(block_manager, info, sub_column_index++, unshredded_type, data_type, this)); + } else { + // leave empty, gets populated by 'SetChildData' + (void)validity; + (void)sub_columns; + } +} + +void VariantColumnData::CreateScanStates(ColumnScanState &state) { + //! 
Re-initialize the scan state, since VARIANT can have a different shape for every RowGroup + state.child_states.clear(); + + state.child_states.emplace_back(state.parent); + state.child_states[0].scan_options = state.scan_options; + + auto unshredded_type = VariantShredding::GetUnshreddedType(); + state.child_states.emplace_back(state.parent); + state.child_states[1].Initialize(state.context, unshredded_type, state.scan_options); + if (IsShredded()) { + auto &shredded_column = sub_columns[1]; + state.child_states.emplace_back(state.parent); + state.child_states[2].Initialize(state.context, shredded_column->type, state.scan_options); + } +} + +idx_t VariantColumnData::GetMaxEntry() { + return sub_columns[0]->GetMaxEntry(); +} + +void VariantColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t rows) { + validity->InitializePrefetch(prefetch_state, scan_state.child_states[0], rows); + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->InitializePrefetch(prefetch_state, scan_state.child_states[i + 1], rows); + } +} + +void VariantColumnData::InitializeScan(ColumnScanState &state) { + CreateScanStates(state); + state.current = nullptr; + + // initialize the validity segment + validity->InitializeScan(state.child_states[0]); + + // initialize the sub-columns + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->InitializeScan(state.child_states[i + 1]); + } +} + +void VariantColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) { + CreateScanStates(state); + state.current = nullptr; + + // initialize the validity segment + validity->InitializeScanWithOffset(state.child_states[0], row_idx); + + // initialize the sub-columns + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->InitializeScanWithOffset(state.child_states[i + 1], row_idx); + } +} + +Vector VariantColumnData::CreateUnshreddingIntermediate(idx_t count) { + D_ASSERT(IsShredded()); + D_ASSERT(sub_columns.size() == 2); + + child_list_t child_types; + child_types.emplace_back("unshredded", sub_columns[0]->type); + child_types.emplace_back("shredded", sub_columns[1]->type); + auto intermediate_type = LogicalType::STRUCT(child_types); + Vector intermediate(intermediate_type, count); + return intermediate; +} + +idx_t VariantColumnData::Scan(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, + idx_t target_count) { + if (IsShredded()) { + auto intermediate = CreateUnshreddingIntermediate(target_count); + auto &child_vectors = StructVector::GetEntries(intermediate); + sub_columns[0]->Scan(transaction, vector_index, state.child_states[1], *child_vectors[0], target_count); + sub_columns[1]->Scan(transaction, vector_index, state.child_states[2], *child_vectors[1], target_count); + auto scan_count = validity->Scan(transaction, vector_index, state.child_states[0], intermediate, target_count); + + VariantColumnData::UnshredVariantData(intermediate, result, target_count); + return scan_count; + } + auto scan_count = validity->Scan(transaction, vector_index, state.child_states[0], result, target_count); + sub_columns[0]->Scan(transaction, vector_index, state.child_states[1], result, target_count); + return scan_count; +} + +idx_t VariantColumnData::ScanCommitted(idx_t vector_index, ColumnScanState &state, Vector &result, bool allow_updates, + idx_t target_count) { + if (IsShredded()) { + auto intermediate = CreateUnshreddingIntermediate(target_count); + + auto &child_vectors = 
StructVector::GetEntries(intermediate); + sub_columns[0]->ScanCommitted(vector_index, state.child_states[1], *child_vectors[0], allow_updates, + target_count); + sub_columns[1]->ScanCommitted(vector_index, state.child_states[2], *child_vectors[1], allow_updates, + target_count); + auto scan_count = + validity->ScanCommitted(vector_index, state.child_states[0], intermediate, allow_updates, target_count); + + VariantColumnData::UnshredVariantData(intermediate, result, target_count); + return scan_count; + } + auto scan_count = + sub_columns[0]->ScanCommitted(vector_index, state.child_states[1], result, allow_updates, target_count); + return scan_count; +} + +idx_t VariantColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t count, idx_t result_offset) { + auto scan_count = sub_columns[0]->ScanCount(state.child_states[1], result, count, result_offset); + return scan_count; +} + +void VariantColumnData::Skip(ColumnScanState &state, idx_t count) { + validity->Skip(state.child_states[0], count); + + // skip inside the sub-columns + for (idx_t child_idx = 0; child_idx < sub_columns.size(); child_idx++) { + sub_columns[child_idx]->Skip(state.child_states[child_idx + 1], count); + } +} + +void VariantColumnData::InitializeAppend(ColumnAppendState &state) { + ColumnAppendState validity_append; + validity->InitializeAppend(validity_append); + state.child_appends.push_back(std::move(validity_append)); + + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + ColumnAppendState child_append; + sub_column->InitializeAppend(child_append); + state.child_appends.push_back(std::move(child_append)); + } +} + +namespace { + +struct VariantShreddedAppendInput { + ColumnData &unshredded; + ColumnData &shredded; + ColumnAppendState &unshredded_append_state; + ColumnAppendState &shredded_append_state; + BaseStatistics &unshredded_stats; + BaseStatistics &shredded_stats; +}; + +} // namespace + +static void AppendShredded(Vector &input, Vector &append_vector, idx_t count, VariantShreddedAppendInput &append_data) { + D_ASSERT(append_vector.GetType().id() == LogicalTypeId::STRUCT); + auto &child_vectors = StructVector::GetEntries(append_vector); + D_ASSERT(child_vectors.size() == 2); + + //! 
Create the new column data for the shredded data + VariantColumnData::ShredVariantData(input, append_vector, count); + auto &unshredded_vector = *child_vectors[0]; + auto &shredded_vector = *child_vectors[1]; + + auto &unshredded = append_data.unshredded; + auto &shredded = append_data.shredded; + + auto &unshredded_stats = append_data.unshredded_stats; + auto &shredded_stats = append_data.shredded_stats; + + auto &unshredded_append_state = append_data.unshredded_append_state; + auto &shredded_append_state = append_data.shredded_append_state; + + unshredded.Append(unshredded_stats, unshredded_append_state, unshredded_vector, count); + shredded.Append(shredded_stats, shredded_append_state, shredded_vector, count); +} + +void VariantColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { + if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { + Vector append_vector(vector); + append_vector.Flatten(count); + Append(stats, state, append_vector, count); + return; + } + + // append the null values + validity->Append(stats, state.child_appends[0], vector, count); + + if (IsShredded()) { + auto &unshredded_type = sub_columns[0]->type; + auto &shredded_type = sub_columns[1]->type; + + auto variant_shredded_type = LogicalType::STRUCT({ + {"unshredded", unshredded_type}, + {"shredded", shredded_type}, + }); + Vector append_vector(variant_shredded_type, count); + + VariantShreddedAppendInput append_data { + *sub_columns[0], + *sub_columns[1], + state.child_appends[1], + state.child_appends[2], + VariantStats::GetUnshreddedStats(stats), + VariantStats::GetShreddedStats(stats), + }; + AppendShredded(vector, append_vector, count, append_data); + } else { + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->Append(VariantStats::GetUnshreddedStats(stats), state.child_appends[i + 1], vector, count); + } + VariantStats::MarkAsNotShredded(stats); + } + this->count += count; +} + +void VariantColumnData::RevertAppend(row_t new_count) { + validity->RevertAppend(new_count); + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + sub_column->RevertAppend(new_count); + } + this->count = UnsafeNumericCast(new_count); +} + +idx_t VariantColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { + throw NotImplementedException("VARIANT Fetch"); +} + +void VariantColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) { + throw NotImplementedException("VARIANT Update is not supported."); +} + +void VariantColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth, idx_t row_group_start) { + throw NotImplementedException("VARIANT Update Column is not supported"); +} + +unique_ptr VariantColumnData::GetUpdateStatistics() { + return nullptr; +} + +void VariantColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, + idx_t result_idx) { + // insert any child states that are required + for (idx_t i = state.child_states.size(); i < sub_columns.size() + 1; i++) { + auto child_state = make_uniq(); + state.child_states.push_back(std::move(child_state)); + } + + if (IsShredded()) { + auto intermediate = CreateUnshreddingIntermediate(result_idx + 1); + auto &child_vectors = StructVector::GetEntries(intermediate); + // fetch the validity 
state + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + // fetch the sub-column states + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->FetchRow(transaction, *state.child_states[i + 1], row_id, *child_vectors[i], result_idx); + } + if (result_idx) { + intermediate.SetValue(0, intermediate.GetValue(result_idx)); + } + + //! FIXME: adjust UnshredVariantData so we can write the value in place into 'result' directly. + Vector unshredded(result.GetType(), 1); + VariantColumnData::UnshredVariantData(intermediate, unshredded, 1); + result.SetValue(result_idx, unshredded.GetValue(0)); + return; + } + + validity->FetchRow(transaction, *state.child_states[0], row_id, result, result_idx); + sub_columns[0]->FetchRow(transaction, *state.child_states[1], row_id, result, result_idx); +} + +void VariantColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { + validity->VisitBlockIds(visitor); + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + sub_column->VisitBlockIds(visitor); + } +} + +void VariantColumnData::SetValidityData(shared_ptr validity_p) { + if (validity) { + throw InternalException("VariantColumnData::SetValidityData cannot be used to overwrite existing validity"); + } + validity_p->SetParent(this); + this->validity = std::move(validity_p); +} + +void VariantColumnData::SetChildData(vector> child_data) { + if (!sub_columns.empty()) { + throw InternalException("VariantColumnData::SetChildData cannot be used to overwrite existing data"); + } + for (auto &col : child_data) { + col->SetParent(this); + } + this->sub_columns = std::move(child_data); +} + +struct VariantColumnCheckpointState : public ColumnCheckpointState { + VariantColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, + PartialBlockManager &partial_block_manager) + : ColumnCheckpointState(row_group, column_data, partial_block_manager) { + global_stats = VariantStats::CreateEmpty(column_data.type).ToUnique(); + } + + vector> shredded_data; + + unique_ptr validity_state; + vector> child_states; + +public: + shared_ptr CreateEmptyColumnData() override { + return make_shared_ptr(original_column.GetBlockManager(), original_column.GetTableInfo(), + original_column.column_index, original_column.type, + ColumnDataType::CHECKPOINT_TARGET, nullptr); + } + + shared_ptr GetFinalResult() override { + if (!result_column) { + result_column = CreateEmptyColumnData(); + } + auto &column_data = result_column->Cast(); + auto validity_child = validity_state->GetFinalResult(); + column_data.SetValidityData(shared_ptr_cast(std::move(validity_child))); + vector> child_data; + for (idx_t i = 0; i < child_states.size(); i++) { + child_data.push_back(child_states[i]->GetFinalResult()); + } + column_data.SetChildData(std::move(child_data)); + return ColumnCheckpointState::GetFinalResult(); + } + + unique_ptr GetStatistics() override { + D_ASSERT(global_stats); + global_stats->Merge(*validity_state->GetStatistics()); + VariantStats::SetUnshreddedStats(*global_stats, child_states[0]->GetStatistics()); + if (child_states.size() == 2) { + VariantStats::SetShreddedStats(*global_stats, child_states[1]->GetStatistics()); + } + return std::move(global_stats); + } + + PersistentColumnData ToPersistentData() override { + PersistentColumnData data(original_column.type); + auto &variant_column_data = GetResultColumn().Cast(); + if (child_states.size() == 2) { + D_ASSERT(variant_column_data.sub_columns.size() == 2); + 
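The append, fetch, and checkpoint-state code above all revolve around the same storage shape: a shredded VARIANT column keeps a generic "unshredded" child next to a typed "shredded" child, and reads recombine the two. The following standalone sketch is a deliberately simplified model of that recombination (ordinary C++ containers, not DuckDB's column data or statistics classes), shown only to illustrate the typed-child-with-fallback idea; the value types and names are hypothetical.

```cpp
#include <cstdio>
#include <cstdlib>
#include <optional>
#include <string>
#include <vector>

struct ShreddedColumnModel {
	std::vector<std::optional<long long>> shredded;   // typed child, e.g. a BIGINT typed_value
	std::vector<std::optional<std::string>> fallback; // generic unshredded fallback

	void Append(const std::string &raw) {
		char *end = nullptr;
		long long as_int = std::strtoll(raw.c_str(), &end, 10);
		if (!raw.empty() && end && *end == '\0') {
			shredded.push_back(as_int); // value fits the shredded type
			fallback.push_back(std::nullopt);
		} else {
			shredded.push_back(std::nullopt);
			fallback.push_back(raw);    // anything else stays in the unshredded child
		}
	}

	std::string Scan(size_t row) const {
		// recombine: prefer the typed child, otherwise use the generic fallback
		if (shredded[row]) {
			return std::to_string(*shredded[row]);
		}
		return *fallback[row];
	}
};

int main() {
	ShreddedColumnModel col;
	col.Append("42");
	col.Append("hello");
	for (size_t i = 0; i < 2; i++) {
		std::printf("row %zu -> %s\n", i, col.Scan(i).c_str());
	}
	return 0;
}
```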
D_ASSERT(variant_column_data.sub_columns[1]->type.id() == LogicalTypeId::STRUCT); + data.SetVariantShreddedType(variant_column_data.sub_columns[1]->type); + } + data.child_columns.push_back(validity_state->ToPersistentData()); + for (auto &child_state : child_states) { + data.child_columns.push_back(child_state->ToPersistentData()); + } + return data; + } +}; + +unique_ptr VariantColumnData::CreateCheckpointState(const RowGroup &row_group, + PartialBlockManager &partial_block_manager) { + return make_uniq(row_group, *this, partial_block_manager); +} + +vector> VariantColumnData::WriteShreddedData(const RowGroup &row_group, + const LogicalType &shredded_type, + BaseStatistics &stats) { + //! scan_chunk + DataChunk scan_chunk; + scan_chunk.Initialize(Allocator::DefaultAllocator(), {LogicalType::VARIANT()}, STANDARD_VECTOR_SIZE); + auto &scan_vector = scan_chunk.data[0]; + + //! append_chunk + auto &child_types = StructType::GetChildTypes(shredded_type); + + DataChunk append_chunk; + append_chunk.Initialize(Allocator::DefaultAllocator(), {shredded_type}, STANDARD_VECTOR_SIZE); + auto &append_vector = append_chunk.data[0]; + + //! Create the new column data for the shredded data + D_ASSERT(child_types.size() == 2); + auto &unshredded_type = child_types[0].second; + auto &typed_value_type = child_types[1].second; + + vector> ret(2); + ret[0] = ColumnData::CreateColumn(block_manager, info, 1, unshredded_type, GetDataType(), this); + ret[1] = ColumnData::CreateColumn(block_manager, info, 2, typed_value_type, GetDataType(), this); + auto &unshredded = ret[0]; + auto &shredded = ret[1]; + + ColumnAppendState unshredded_append_state; + unshredded->InitializeAppend(unshredded_append_state); + + ColumnAppendState shredded_append_state; + shredded->InitializeAppend(shredded_append_state); + + ColumnScanState scan_state(nullptr); + + InitializeScan(scan_state); + //! Scan + transform + append + idx_t total_count = count.load(); + + auto transformed_stats = VariantStats::CreateShredded(typed_value_type); + auto &unshredded_stats = VariantStats::GetUnshreddedStats(transformed_stats); + auto &shredded_stats = VariantStats::GetShreddedStats(transformed_stats); + + VariantShreddedAppendInput append_data {*unshredded, *shredded, unshredded_append_state, + shredded_append_state, unshredded_stats, shredded_stats}; + idx_t vector_index = 0; + for (idx_t scanned = 0; scanned < total_count; scanned += STANDARD_VECTOR_SIZE) { + scan_chunk.Reset(); + auto to_scan = MinValue(total_count - scanned, static_cast(STANDARD_VECTOR_SIZE)); + ScanCommitted(vector_index++, scan_state, scan_vector, false, to_scan); + append_chunk.Reset(); + + AppendShredded(scan_vector, append_vector, to_scan, append_data); + } + stats = std::move(transformed_stats); + return ret; +} + +LogicalType VariantColumnData::GetShreddedType() { + VariantShreddingStats variant_stats; + + //! 
scan_chunk + DataChunk scan_chunk; + scan_chunk.Initialize(Allocator::DefaultAllocator(), {LogicalType::VARIANT()}, STANDARD_VECTOR_SIZE); + auto &scan_vector = scan_chunk.data[0]; + + ColumnScanState scan_state(nullptr); + InitializeScan(scan_state); + idx_t total_count = count.load(); + idx_t vector_index = 0; + for (idx_t scanned = 0; scanned < total_count; scanned += STANDARD_VECTOR_SIZE) { + scan_chunk.Reset(); + auto to_scan = MinValue(total_count - scanned, static_cast(STANDARD_VECTOR_SIZE)); + ScanCommitted(vector_index++, scan_state, scan_vector, false, to_scan); + variant_stats.Update(scan_vector, to_scan); + } + + return variant_stats.GetShreddedType(); +} + +static bool EnableShredding(int64_t minimum_size, idx_t current_size) { + if (minimum_size == -1) { + //! Shredding is entirely disabled + return false; + } + return current_size >= static_cast(minimum_size); +} + +unique_ptr VariantColumnData::Checkpoint(const RowGroup &row_group, + ColumnCheckpointInfo &checkpoint_info) { + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); + checkpoint_state->validity_state = validity->Checkpoint(row_group, checkpoint_info); + + auto &table_info = row_group.GetTableInfo(); + auto &db = table_info.GetDB(); + auto &config_options = DBConfig::Get(db).options; + + bool should_shred = true; + if (!HasAnyChanges()) { + should_shred = false; + } + if (!EnableShredding(config_options.variant_minimum_shredding_size, row_group.count.load())) { + should_shred = false; + } + + LogicalType shredded_type; + if (should_shred) { + if (config_options.force_variant_shredding.id() != LogicalTypeId::INVALID) { + shredded_type = config_options.force_variant_shredding; + } else { + shredded_type = GetShreddedType(); + } + D_ASSERT(shredded_type.id() == LogicalTypeId::STRUCT); + auto &type_entries = StructType::GetChildTypes(shredded_type); + if (type_entries.size() != 2) { + //! We couldn't determine a shredding type from the data + should_shred = false; + } + } + + if (!should_shred) { + for (idx_t i = 0; i < sub_columns.size(); i++) { + checkpoint_state->child_states.push_back(sub_columns[i]->Checkpoint(row_group, checkpoint_info)); + } + return std::move(checkpoint_state); + } + + //! STRUCT(unshredded VARIANT, shredded <...>) + BaseStatistics column_stats = BaseStatistics::CreateEmpty(shredded_type); + checkpoint_state->shredded_data = WriteShreddedData(row_group, shredded_type, column_stats); + D_ASSERT(checkpoint_state->shredded_data.size() == 2); + auto &unshredded = checkpoint_state->shredded_data[0]; + auto &shredded = checkpoint_state->shredded_data[1]; + + //! 
Now checkpoint the shredded data + checkpoint_state->child_states.push_back(unshredded->Checkpoint(row_group, checkpoint_info)); + checkpoint_state->child_states.push_back(shredded->Checkpoint(row_group, checkpoint_info)); + + return std::move(checkpoint_state); +} + +bool VariantColumnData::IsPersistent() { + if (!validity->IsPersistent()) { + return false; + } + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + if (!sub_column->IsPersistent()) { + return false; + } + } + return true; +} + +bool VariantColumnData::HasAnyChanges() const { + if (validity->HasAnyChanges()) { + return true; + } + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + if (sub_column->HasAnyChanges()) { + return true; + } + } + return false; +} + +PersistentColumnData VariantColumnData::Serialize() { + PersistentColumnData persistent_data(type); + if (IsShredded()) { + persistent_data.SetVariantShreddedType(sub_columns[1]->type); + } + persistent_data.child_columns.push_back(validity->Serialize()); + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + persistent_data.child_columns.push_back(sub_column->Serialize()); + } + return persistent_data; +} + +void VariantColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) { + validity->InitializeColumn(column_data.child_columns[0], target_stats); + + if (column_data.child_columns.size() == 3) { + //! This means the VARIANT is shredded + auto &unshredded_stats = VariantStats::GetUnshreddedStats(target_stats); + sub_columns[0]->InitializeColumn(column_data.child_columns[1], unshredded_stats); + + auto &shredded_type = column_data.variant_shredded_type; + if (!IsShredded()) { + VariantStats::SetShreddedStats(target_stats, BaseStatistics::CreateEmpty(shredded_type)); + sub_columns.push_back(ColumnData::CreateColumn(block_manager, info, 2, shredded_type, GetDataType(), this)); + } + auto &shredded_stats = VariantStats::GetShreddedStats(target_stats); + sub_columns[1]->InitializeColumn(column_data.child_columns[2], shredded_stats); + } else { + auto &unshredded_stats = VariantStats::GetUnshreddedStats(target_stats); + sub_columns[0]->InitializeColumn(column_data.child_columns[1], unshredded_stats); + } + this->count = validity->count.load(); +} + +void VariantColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { + col_path.push_back(0); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); + for (idx_t i = 0; i < sub_columns.size(); i++) { + col_path.back() = i + 1; + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); + } +} + +void VariantColumnData::Verify(RowGroup &parent) { +#ifdef DEBUG + ColumnData::Verify(parent); + validity->Verify(parent); + for (idx_t i = 0; i < sub_columns.size(); i++) { + auto &sub_column = sub_columns[i]; + sub_column->Verify(parent); + } +#endif +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table_index_list.cpp b/src/duckdb/src/storage/table_index_list.cpp index ade84cdc8..77f1f6581 100644 --- a/src/duckdb/src/storage/table_index_list.cpp +++ b/src/duckdb/src/storage/table_index_list.cpp @@ -147,11 +147,17 @@ void TableIndexList::Bind(ClientContext &context, DataTableInfo &table_info, con // Create an IndexBinder to bind the index IndexBinder idx_binder(*binder, context); - // Apply any outstanding appends and replace the unbound index with a bound index. 
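The checkpoint path above only shreds when it is worthwhile: shredding must not be disabled, the column must have changes, the row group must meet the minimum-size threshold, and a usable two-child typed layout must be inferable (or forced via the config option). The sketch below restates that decision flow as a standalone function; it mirrors EnableShredding from the patch but the surrounding names and the InferredShreddedType stand-in are assumptions for illustration, not DuckDB API.

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>

struct InferredShreddedType {
	// stand-in for STRUCT(untyped_value_index, typed_value): both children must be present
	bool has_untyped_child = false;
	bool has_typed_child = false;
};

static bool EnableShredding(int64_t minimum_size, uint64_t current_size) {
	if (minimum_size == -1) {
		return false; // shredding entirely disabled
	}
	return current_size >= static_cast<uint64_t>(minimum_size);
}

static bool ShouldShred(int64_t minimum_size, uint64_t row_group_count, bool has_changes,
                        const std::optional<InferredShreddedType> &inferred) {
	if (!has_changes) {
		return false; // nothing new to write for this column
	}
	if (!EnableShredding(minimum_size, row_group_count)) {
		return false; // disabled, or the row group is too small
	}
	// a usable shredded type needs both the untyped fallback and a typed child
	return inferred && inferred->has_untyped_child && inferred->has_typed_child;
}

int main() {
	InferredShreddedType t{true, true};
	std::printf("%d\n", ShouldShred(2048, 10000, true, t)); // 1: large enough, type inferred
	std::printf("%d\n", ShouldShred(-1, 10000, true, t));   // 0: shredding disabled
	std::printf("%d\n", ShouldShred(2048, 100, true, t));   // 0: below the minimum size
	return 0;
}
```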
+ // Apply any outstanding buffered replays and replace the unbound index with a bound index. auto &unbound_index = index_entry->index->Cast(); auto bound_idx = idx_binder.BindIndex(unbound_index); - if (unbound_index.HasBufferedAppends()) { - bound_idx->ApplyBufferedAppends(column_types, unbound_index.GetBufferedAppends(), + if (unbound_index.HasBufferedReplays()) { + // For replaying buffered index operations, we only want the physical column types (skip over + // generated column types). + vector physical_column_types; + for (auto &col : table.GetColumns().Physical()) { + physical_column_types.push_back(col.Type()); + } + bound_idx->ApplyBufferedReplays(physical_column_types, unbound_index.GetBufferedReplays(), unbound_index.GetMappedColumnIds()); } @@ -255,11 +261,18 @@ void TableIndexList::InitializeIndexChunk(DataChunk &index_chunk, const vector index_types; + // Store the mapped_column_ids and index_types in sorted canonical form, needed for + // buffering WAL index operations during replay (see notes in unbound_index.hpp). + // First sort mapped_column_ids, then populate index_types according to the sorted order. for (auto &col : indexed_columns) { - index_types.push_back(table_types[col]); mapped_column_ids.emplace_back(col); } + std::sort(mapped_column_ids.begin(), mapped_column_ids.end()); + + vector index_types; + for (auto &col : mapped_column_ids) { + index_types.push_back(table_types[col.GetPrimaryIndex()]); + } index_chunk.InitializeEmpty(index_types); } diff --git a/src/duckdb/src/storage/temporary_file_manager.cpp b/src/duckdb/src/storage/temporary_file_manager.cpp index b8ab5a7b0..e27ef0729 100644 --- a/src/duckdb/src/storage/temporary_file_manager.cpp +++ b/src/duckdb/src/storage/temporary_file_manager.cpp @@ -73,7 +73,6 @@ TemporaryFileIdentifier::TemporaryFileIdentifier(TemporaryBufferSize size_p, idx TemporaryFileIdentifier::TemporaryFileIdentifier(DatabaseInstance &db, TemporaryBufferSize size_p, idx_t file_index_p, bool encrypted_p) : size(size_p), file_index(file_index_p), encrypted(encrypted_p) { - if (encrypted) { // generate a random encryption key ID and corresponding key EncryptionEngine::AddTempKeyToCache(db); diff --git a/src/duckdb/src/storage/wal_replay.cpp b/src/duckdb/src/storage/wal_replay.cpp index 77eca9cf7..b5c8e0ada 100644 --- a/src/duckdb/src/storage/wal_replay.cpp +++ b/src/duckdb/src/storage/wal_replay.cpp @@ -32,6 +32,8 @@ #include "duckdb/storage/table/delete_state.hpp" #include "duckdb/storage/write_ahead_log.hpp" #include "duckdb/transaction/meta_transaction.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -46,6 +48,8 @@ class ReplayState { optional_ptr current_table; MetaBlockPointer checkpoint_id; idx_t wal_version = 1; + optional_idx current_position; + optional_idx checkpoint_position; optional_idx expected_checkpoint_id; struct ReplayIndexInfo { @@ -75,8 +79,8 @@ class WriteAheadLogDeserializer { deserializer.Set(catalog); } - static WriteAheadLogDeserializer Open(ReplayState &state_p, BufferedFileReader &stream, - bool deserialize_only = false) { + static WriteAheadLogDeserializer GetEntryDeserializer(ReplayState &state_p, BufferedFileReader &stream, + bool deserialize_only = false) { if (state_p.wal_version == 1) { // old WAL versions do not have checksums return WriteAheadLogDeserializer(state_p, stream, deserialize_only); @@ -256,42 +260,65 @@ class WriteAheadLogDeserializer { //===--------------------------------------------------------------------===// // Replay 
//===--------------------------------------------------------------------===// -unique_ptr WriteAheadLog::Replay(FileSystem &fs, AttachedDatabase &db, const string &wal_path) { +unique_ptr WriteAheadLog::Replay(QueryContext context, StorageManager &storage_manager, + const string &wal_path) { + auto &fs = FileSystem::Get(storage_manager.GetAttached()); auto handle = fs.OpenFile(wal_path, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS); if (!handle) { // WAL does not exist - instantiate an empty WAL - return make_uniq(db, wal_path); + return make_uniq(storage_manager, wal_path); } - auto wal_handle = ReplayInternal(db, std::move(handle)); + + // context is passed for metric collection purposes only!! + auto wal_handle = ReplayInternal(context, storage_manager, std::move(handle)); if (wal_handle) { return wal_handle; } // replay returning NULL indicates we can nuke the WAL entirely - but only if this is not a read-only connection - if (!db.IsReadOnly()) { - fs.RemoveFile(wal_path); + if (!storage_manager.GetAttached().IsReadOnly()) { + fs.TryRemoveFile(wal_path); + } + return make_uniq(storage_manager, wal_path); +} + +static void CopyOverWAL(QueryContext context, BufferedFileReader &reader, FileHandle &target, data_ptr_t buffer, + idx_t buffer_size, idx_t copy_end) { + while (!reader.Finished()) { + idx_t read_count = MinValue(buffer_size, copy_end - reader.CurrentOffset()); + if (read_count == 0) { + break; + } + reader.ReadData(context, buffer, read_count); + + target.Write(buffer, read_count); } - return make_uniq(db, wal_path); } -unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &database, unique_ptr handle) { + +unique_ptr WriteAheadLog::ReplayInternal(QueryContext context, StorageManager &storage_manager, + unique_ptr handle, WALReplayState replay_state) { + auto &database = storage_manager.GetAttached(); Connection con(database.GetDatabase()); auto wal_path = handle->GetPath(); BufferedFileReader reader(FileSystem::Get(database), std::move(handle)); if (reader.Finished()) { - // WAL file exists but it is empty - we can delete the file + // WAL file exists, but it is empty - we can delete the file return nullptr; } con.BeginTransaction(); - MetaTransaction::Get(*con.context).ModifyDatabase(database); + MetaTransaction::Get(*con.context).ModifyDatabase(database, DatabaseModificationType()); auto &config = DBConfig::GetConfig(database.GetDatabase()); // first deserialize the WAL to look for a checkpoint flag // if there is a checkpoint flag, we might have already flushed the contents of the WAL to disk ReplayState checkpoint_state(database, *con.context); try { + idx_t replay_entry_count = 0; while (true) { + replay_entry_count++; // read the current entry (deserialize only) - auto deserializer = WriteAheadLogDeserializer::Open(checkpoint_state, reader, true); + checkpoint_state.current_position = reader.CurrentOffset(); + auto deserializer = WriteAheadLogDeserializer::GetEntryDeserializer(checkpoint_state, reader, true); if (deserializer.ReplayEntry()) { // check if the file is exhausted if (reader.Finished()) { @@ -300,6 +327,11 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa } } } + auto client_context = context.GetClientContext(); + if (client_context) { + auto &profiler = *client_context->client_data->profiler; + profiler.AddToCounter(MetricType::WAL_REPLAY_ENTRY_COUNT, replay_entry_count); + } } catch (std::exception &ex) { // LCOV_EXCL_START ErrorData error(ex); // ignore serialization exceptions - they signal a torn WAL @@ 
-307,13 +339,115 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa error.Throw("Failure while replaying WAL file \"" + wal_path + "\": "); } } // LCOV_EXCL_STOP + unique_ptr checkpoint_handle; if (checkpoint_state.checkpoint_id.IsValid()) { - // there is a checkpoint flag: check if we need to deserialize the WAL + if (replay_state == WALReplayState::CHECKPOINT_WAL) { + throw InvalidInputException( + "Failure while replaying checkpoint WAL file \"%s\": checkpoint WAL cannot contain a checkpoint marker", + wal_path); + } + // there is a checkpoint flag + // this means a checkpoint was on-going when we crashed + // we need to reconcile this with what is in the data file + // first check if there is a checkpoint WAL auto &manager = database.GetStorageManager(); - if (manager.IsCheckpointClean(checkpoint_state.checkpoint_id)) { - // the contents of the WAL have already been checkpointed - // we can safely truncate the WAL and ignore its contents - return nullptr; + auto &fs = FileSystem::Get(storage_manager.GetAttached()); + auto checkpoint_wal = manager.GetCheckpointWALPath(); + checkpoint_handle = + fs.OpenFile(checkpoint_wal, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS); + bool checkpoint_was_successful = manager.IsCheckpointClean(checkpoint_state.checkpoint_id); + if (!checkpoint_handle) { + // no checkpoint WAL - either we just need to replay this WAL, or we are done + if (checkpoint_was_successful) { + // the contents of the WAL have already been checkpointed and there is no checkpoint WAL - we are done + return nullptr; + } + } else { + // we have a checkpoint WAL + if (checkpoint_was_successful) { + // the checkpoint was successful + // the main WAL is no longer needed, we only need to replay the checkpoint WAL + // if this is a read-only connection then replay the checkpoint WAL directly + if (storage_manager.GetAttached().IsReadOnly()) { + return ReplayInternal(context, storage_manager, std::move(checkpoint_handle), + WALReplayState::CHECKPOINT_WAL); + } + // if this is not a read-only connection we need to finish the checkpoint + // overwrite the current WAL with the checkpoint WAL + checkpoint_handle.reset(); + + fs.MoveFile(checkpoint_wal, wal_path); + + // now open the handle again and replay the checkpoint WAL + checkpoint_handle = + fs.OpenFile(wal_path, FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS); + return ReplayInternal(context, storage_manager, std::move(checkpoint_handle), + WALReplayState::CHECKPOINT_WAL); + } + // the checkpoint was unsuccessful + // this means we need to replay both this WAL and the checkpoint WAL + // if this is a read-only connection - replay both WAL files + if (!storage_manager.GetAttached().IsReadOnly()) { + // if this is not a read-only connection, then merge the two WALs and replay the merged WAL + // we merge into the recovery WAL path + auto recovery_path = manager.GetRecoveryWALPath(); + auto recovery_handle = + fs.OpenFile(recovery_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); + + static constexpr idx_t BATCH_SIZE = Storage::DEFAULT_BLOCK_SIZE; + auto buffer = make_uniq_array(BATCH_SIZE); + + // first copy over the main WAL contents + auto copy_end = checkpoint_state.checkpoint_position.GetIndex(); + reader.Reset(); + CopyOverWAL(context, reader, *recovery_handle, buffer.get(), BATCH_SIZE, copy_end); + + // now copy over the checkpoint WAL + { + BufferedFileReader checkpoint_reader(FileSystem::Get(database), std::move(checkpoint_handle)); + + 
// skip over the version entry + ReplayState checkpoint_replay_state(database, *con.context); + auto deserializer = WriteAheadLogDeserializer::GetEntryDeserializer(checkpoint_replay_state, + checkpoint_reader, true); + deserializer.ReplayEntry(); + + if (checkpoint_replay_state.wal_version != checkpoint_state.wal_version) { + throw InvalidInputException("Failure while replaying checkpoint WAL file \"%s\": checkpoint " + "WAL version is different from main WAL version", + wal_path); + } + + CopyOverWAL(context, checkpoint_reader, *recovery_handle, buffer.get(), BATCH_SIZE, + checkpoint_reader.FileSize()); + } + + auto debug_checkpoint_abort = + DBConfig::GetSetting(storage_manager.GetDatabase()); + + // move over the recovery WAL over the main WAL + recovery_handle->Sync(); + recovery_handle.reset(); + + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_MOVING_RECOVERY) { + throw FatalException( + "Checkpoint aborted before moving recovery file because of PRAGMA checkpoint_abort flag"); + } + + fs.MoveFile(recovery_path, wal_path); + + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_DELETING_CHECKPOINT_WAL) { + throw FatalException( + "Checkpoint aborted before deleting checkpoint file because of PRAGMA checkpoint_abort flag"); + } + + // delete the checkpoint WAL + fs.RemoveFile(checkpoint_wal); + + // replay the (combined) recovery WAL + auto main_handle = fs.OpenFile(wal_path, FileFlags::FILE_FLAGS_READ); + return ReplayInternal(context, storage_manager, std::move(main_handle), WALReplayState::CHECKPOINT_WAL); + } } } if (checkpoint_state.expected_checkpoint_id.IsValid()) { @@ -336,7 +470,7 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa try { while (true) { // read the current entry - auto deserializer = WriteAheadLogDeserializer::Open(state, reader); + auto deserializer = WriteAheadLogDeserializer::GetEntryDeserializer(state, reader); if (deserializer.ReplayEntry()) { con.Commit(); @@ -354,7 +488,7 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa break; } con.BeginTransaction(); - MetaTransaction::Get(*con.context).ModifyDatabase(database); + MetaTransaction::Get(*con.context).ModifyDatabase(database, DatabaseModificationType()); } } } catch (std::exception &ex) { // LCOV_EXCL_START @@ -372,8 +506,14 @@ unique_ptr WriteAheadLog::ReplayInternal(AttachedDatabase &databa con.Query("ROLLBACK"); throw; } // LCOV_EXCL_STOP + if (all_succeeded && checkpoint_handle) { + // we have successfully replayed the main WAL - but there is still a checkpoint WAL remaining + // this can only happen in read-only mode + // replay the checkpoint WAL and return + return ReplayInternal(context, storage_manager, std::move(checkpoint_handle), WALReplayState::CHECKPOINT_WAL); + } auto init_state = all_succeeded ? WALInitState::UNINITIALIZED : WALInitState::UNINITIALIZED_REQUIRES_TRUNCATE; - return make_uniq(database, wal_path, successful_offset, init_state); + return make_uniq(storage_manager, wal_path, successful_offset, init_state); } //===--------------------------------------------------------------------===// @@ -562,7 +702,6 @@ void WriteAheadLogDeserializer::ReplayIndexData(IndexStorageInfo &info) { // Read the data into buffer handles and convert them to blocks on disk. for (idx_t j = 0; j < data_info.allocation_sizes.size(); j++) { - // Read the data into a buffer handle. 
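The WAL reconciliation above splices two logs into one: the main WAL up to the checkpoint marker, followed by the checkpoint WAL minus its version entry, written into a recovery file that is synced, renamed over the main WAL, and only then followed by deletion of the checkpoint WAL, so a crash at any step leaves a replayable file. The sketch below shows that ordering with plain std::filesystem and fstream (not DuckDB's FileSystem or BufferedFileReader); the offset parameters are hypothetical stand-ins for the checkpoint marker position and the checkpoint WAL's header size.

```cpp
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <vector>

namespace fs = std::filesystem;

// append bytes [offset, end) of `source` to `out`
static void CopyRange(std::ofstream &out, const fs::path &source, std::streamoff offset, std::streamoff end) {
	std::ifstream in(source, std::ios::binary);
	in.seekg(offset);
	std::vector<char> buffer(1 << 16);
	std::streamoff pos = offset;
	while (in && pos < end) {
		auto to_read = std::min<std::streamoff>(end - pos, static_cast<std::streamoff>(buffer.size()));
		in.read(buffer.data(), static_cast<std::streamsize>(to_read));
		out.write(buffer.data(), in.gcount());
		pos += in.gcount();
	}
}

static void MergeWALs(const fs::path &main_wal, std::streamoff checkpoint_marker_offset,
                      const fs::path &checkpoint_wal, std::streamoff checkpoint_header_size,
                      const fs::path &recovery_wal) {
	{
		std::ofstream out(recovery_wal, std::ios::binary | std::ios::trunc);
		// main WAL entries up to (not including) the checkpoint marker
		CopyRange(out, main_wal, 0, checkpoint_marker_offset);
		// checkpoint WAL entries, skipping its version/header entry
		CopyRange(out, checkpoint_wal, checkpoint_header_size,
		          static_cast<std::streamoff>(fs::file_size(checkpoint_wal)));
		out.flush(); // stands in for the Sync() before the rename
	}
	fs::rename(recovery_wal, main_wal); // replace the main WAL with the merged file
	fs::remove(checkpoint_wal);         // removed only after the merged WAL is in place
}
```

Either WAL alone is replayable until the rename happens, and after the rename the merged file is; that is what makes the debug abort points before the move and before the delete safe to test.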
auto buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.get(), false); auto block_handle = buffer_handle.GetBlockHandle(); @@ -572,8 +711,8 @@ void WriteAheadLogDeserializer::ReplayIndexData(IndexStorageInfo &info) { // Convert the buffer handle to a persistent block and store the block id. if (!deserialize_only) { - auto block_id = block_manager->GetFreeBlockId(); - block_manager->ConvertToPersistent(QueryContext(context), block_id, std::move(block_handle), + auto block_id = block_manager->GetFreeBlockIdForCheckpoint(); + block_manager->ConvertToPersistent(context, block_id, std::move(block_handle), std::move(buffer_handle)); data_info.block_pointers[j].block_id = block_id; } @@ -942,21 +1081,20 @@ void WriteAheadLogDeserializer::ReplayDelete() { } D_ASSERT(chunk.ColumnCount() == 1 && chunk.data[0].GetType() == LogicalType::ROW_TYPE); - row_t row_ids[1]; - Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_ids)); - auto source_ids = FlatVector::GetData(chunk.data[0]); + auto &row_identifiers = chunk.data[0]; + row_identifiers.Flatten(chunk.size()); + auto source_ids = FlatVector::GetData(row_identifiers); // Delete the row IDs from the current table. auto &storage = state.current_table->GetStorage(); auto total_rows = storage.GetTotalRows(); - TableDeleteState delete_state; for (idx_t i = 0; i < chunk.size(); i++) { if (source_ids[i] >= UnsafeNumericCast(total_rows)) { throw SerializationException("invalid row ID delete in WAL"); } - row_ids[0] = source_ids[i]; - storage.Delete(delete_state, context, row_identifiers, 1); } + TableDeleteState delete_state; + storage.Delete(delete_state, context, row_identifiers, chunk.size()); } void WriteAheadLogDeserializer::ReplayUpdate() { @@ -986,6 +1124,7 @@ void WriteAheadLogDeserializer::ReplayUpdate() { void WriteAheadLogDeserializer::ReplayCheckpoint() { state.checkpoint_id = deserializer.ReadProperty(101, "meta_block"); + state.checkpoint_position = state.current_position; } } // namespace duckdb diff --git a/src/duckdb/src/storage/write_ahead_log.cpp b/src/duckdb/src/storage/write_ahead_log.cpp index 57689386a..cc565bc96 100644 --- a/src/duckdb/src/storage/write_ahead_log.cpp +++ b/src/duckdb/src/storage/write_ahead_log.cpp @@ -28,16 +28,18 @@ namespace duckdb { constexpr uint64_t WAL_VERSION_NUMBER = 2; constexpr uint64_t WAL_ENCRYPTED_VERSION_NUMBER = 3; -WriteAheadLog::WriteAheadLog(AttachedDatabase &database, const string &wal_path, idx_t wal_size, - WALInitState init_state) - : database(database), wal_path(wal_path), wal_size(wal_size), init_state(init_state) { +WriteAheadLog::WriteAheadLog(StorageManager &storage_manager, const string &wal_path, idx_t wal_size, + WALInitState init_state, optional_idx checkpoint_iteration) + : storage_manager(storage_manager), wal_path(wal_path), init_state(init_state), + checkpoint_iteration(checkpoint_iteration) { + storage_manager.SetWALSize(wal_size); } WriteAheadLog::~WriteAheadLog() { } AttachedDatabase &WriteAheadLog::GetDatabase() { - return database; + return storage_manager.GetAttached(); } BufferedFileWriter &WriteAheadLog::Initialize() { @@ -47,24 +49,19 @@ BufferedFileWriter &WriteAheadLog::Initialize() { lock_guard lock(wal_lock); if (!writer) { writer = - make_uniq(FileSystem::Get(database), wal_path, + make_uniq(FileSystem::Get(GetDatabase()), wal_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE | FileFlags::FILE_FLAGS_APPEND | FileFlags::FILE_FLAGS_MULTI_CLIENT_ACCESS); if (init_state == 
WALInitState::UNINITIALIZED_REQUIRES_TRUNCATE) { - writer->Truncate(wal_size); + writer->Truncate(storage_manager.GetWALSize()); + } else { + storage_manager.SetWALSize(writer->GetFileSize()); } - wal_size = writer->GetFileSize(); init_state = WALInitState::INITIALIZED; } return *writer; } -//! Gets the total bytes written to the WAL since startup -idx_t WriteAheadLog::GetWALSize() const { - D_ASSERT(init_state != WALInitState::NO_WAL || wal_size == 0); - return wal_size; -} - idx_t WriteAheadLog::GetTotalWritten() const { if (!Initialized()) { return 0; @@ -79,29 +76,17 @@ void WriteAheadLog::Truncate(idx_t size) { } if (!Initialized()) { init_state = WALInitState::UNINITIALIZED_REQUIRES_TRUNCATE; - wal_size = size; + storage_manager.SetWALSize(size); return; } writer->Truncate(size); - wal_size = writer->GetFileSize(); + storage_manager.SetWALSize(writer->GetFileSize()); } bool WriteAheadLog::Initialized() const { return init_state == WALInitState::INITIALIZED; } -void WriteAheadLog::Delete() { - if (init_state == WALInitState::NO_WAL) { - // no WAL to delete - return; - } - writer.reset(); - auto &fs = FileSystem::Get(database); - fs.TryRemoveFile(wal_path); - init_state = WALInitState::NO_WAL; - wal_size = 0; -} - //===--------------------------------------------------------------------===// // Serializer //===--------------------------------------------------------------------===// @@ -254,6 +239,7 @@ void WriteAheadLog::WriteHeader() { serializer.Begin(); serializer.WriteProperty(100, "wal_type", WALType::WAL_VERSION); + auto &database = GetDatabase(); auto &catalog = database.GetCatalog().Cast(); auto encryption_version_number = catalog.GetIsEncrypted() ? idx_t(WAL_ENCRYPTED_VERSION_NUMBER) : idx_t(WAL_VERSION_NUMBER); @@ -265,8 +251,13 @@ void WriteAheadLog::WriteHeader() { auto db_identifier = single_file_block_manager.GetDBIdentifier(); serializer.WriteList(102, "db_identifier", MainHeader::DB_IDENTIFIER_LEN, [&](Serializer::List &list, idx_t i) { list.WriteElement(db_identifier[i]); }); - auto checkpoint_iteration = single_file_block_manager.GetCheckpointIteration(); - serializer.WriteProperty(103, "checkpoint_iteration", checkpoint_iteration); + idx_t current_checkpoint_iteration; + if (checkpoint_iteration.IsValid()) { + current_checkpoint_iteration = checkpoint_iteration.GetIndex(); + } else { + current_checkpoint_iteration = single_file_block_manager.GetCheckpointIteration(); + } + serializer.WriteProperty(103, "checkpoint_iteration", current_checkpoint_iteration); } serializer.End(); @@ -399,6 +390,7 @@ void WriteAheadLog::WriteCreateIndex(const IndexCatalogEntry &entry) { // Serialize the index data to the persistent storage and write the metadata. 
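The ReplayDelete rewrite above switches from issuing one Delete call per row identifier to validating every row ID up front and then deleting the whole chunk in a single call. The following is a rough standalone sketch of that validate-then-batch pattern (plain C++, hypothetical TableModel type, not the DataTable API), included only to make the control-flow change concrete.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

using row_t = int64_t;

struct TableModel {
	uint64_t total_rows = 0;
	uint64_t deleted = 0;

	// stand-in for a delete call that takes a whole batch of row identifiers
	void Delete(const std::vector<row_t> &row_ids) {
		deleted += row_ids.size();
	}
};

static void ReplayDeleteBatch(TableModel &table, const std::vector<row_t> &row_ids) {
	// validate first: a row ID beyond the table size means the WAL entry is corrupt
	for (row_t row_id : row_ids) {
		if (row_id < 0 || static_cast<uint64_t>(row_id) >= table.total_rows) {
			throw std::runtime_error("invalid row ID delete in WAL");
		}
	}
	// then delete the whole batch in one call instead of once per row
	table.Delete(row_ids);
}
```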
auto &index_entry = entry.Cast(); auto &list = index_entry.GetDataTableInfo().GetIndexes(); + auto &database = GetDatabase(); SerializeIndex(database, serializer, list, index_entry.name); serializer.End(); } @@ -521,6 +513,7 @@ void WriteAheadLog::WriteAlter(CatalogEntry &entry, const AlterInfo &info) { auto &list = parent_info->GetIndexes(); auto name = unique.GetName(parent.name); + auto &database = GetDatabase(); SerializeIndex(database, serializer, list, name); serializer.End(); } @@ -539,7 +532,7 @@ void WriteAheadLog::Flush() { // flushes all changes made to the WAL to disk writer->Sync(); - wal_size = writer->GetFileSize(); + storage_manager.SetWALSize(writer->GetFileSize()); } } // namespace duckdb diff --git a/src/duckdb/src/transaction/cleanup_state.cpp b/src/duckdb/src/transaction/cleanup_state.cpp index f9a17f265..1a07bf6ee 100644 --- a/src/duckdb/src/transaction/cleanup_state.cpp +++ b/src/duckdb/src/transaction/cleanup_state.cpp @@ -10,15 +10,14 @@ #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/table/update_segment.hpp" #include "duckdb/storage/table/row_version_manager.hpp" +#include "duckdb/transaction/commit_state.hpp" namespace duckdb { -CleanupState::CleanupState(transaction_t lowest_active_transaction) - : lowest_active_transaction(lowest_active_transaction), current_table(nullptr), count(0) { -} - -CleanupState::~CleanupState() { - Flush(); +CleanupState::CleanupState(const QueryContext &context, transaction_t lowest_active_transaction, + ActiveTransactionState transaction_state) + : lowest_active_transaction(lowest_active_transaction), transaction_state(transaction_state), + index_data_remover(context, IndexRemovalType::DELETED_ROWS_IN_USE) { } void CleanupState::CleanupEntry(UndoFlags type, data_ptr_t data) { @@ -58,50 +57,12 @@ void CleanupState::CleanupUpdate(UpdateInfo &info) { } void CleanupState::CleanupDelete(DeleteInfo &info) { - auto version_table = info.table; - if (!version_table->HasIndexes()) { - // this table has no indexes: no cleanup to be done + if (transaction_state == ActiveTransactionState::NO_OTHER_TRANSACTIONS) { + // if there are no active transactions we don't need to do any clean-up, as we haven't written to + // deleted_rows_in_use return; } - - if (current_table != version_table) { - // table for this entry differs from previous table: flush and switch to the new table - Flush(); - current_table = version_table; - } - - // possibly vacuum any indexes in this table later - indexed_tables[current_table->GetTableName()] = current_table; - - count = 0; - if (info.is_consecutive) { - for (idx_t i = 0; i < info.count; i++) { - row_numbers[count++] = UnsafeNumericCast(info.base_row + i); - } - } else { - auto rows = info.GetRows(); - for (idx_t i = 0; i < info.count; i++) { - row_numbers[count++] = UnsafeNumericCast(info.base_row + rows[i]); - } - } - Flush(); -} - -void CleanupState::Flush() { - if (count == 0) { - return; - } - - // set up the row identifiers vector - Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_numbers)); - - // delete the tuples from all the indexes - try { - current_table->RemoveFromIndexes(row_identifiers, count); - } catch (...) 
{ // NOLINT: ignore errors here - } - - count = 0; + index_data_remover.PushDelete(info); } } // namespace duckdb diff --git a/src/duckdb/src/transaction/commit_state.cpp b/src/duckdb/src/transaction/commit_state.cpp index 0f5d75bd2..1819e0c46 100644 --- a/src/duckdb/src/transaction/commit_state.cpp +++ b/src/duckdb/src/transaction/commit_state.cpp @@ -21,8 +21,95 @@ namespace duckdb { -CommitState::CommitState(DuckTransaction &transaction_p, transaction_t commit_id) - : transaction(transaction_p), commit_id(commit_id) { +//===--------------------------------------------------------------------===// +// IndexDataRemover +//===--------------------------------------------------------------------===// +IndexDataRemover::IndexDataRemover(QueryContext context, IndexRemovalType removal_type) + : context(context), removal_type(removal_type) { +} + +void IndexDataRemover::PushDelete(DeleteInfo &info) { + auto &version_table = *info.table; + if (!version_table.HasIndexes()) { + // this table has no indexes: no cleanup to be done + return; + } + + idx_t count = 0; + row_t row_numbers[STANDARD_VECTOR_SIZE]; + if (info.is_consecutive) { + for (idx_t i = 0; i < info.count; i++) { + row_numbers[count++] = UnsafeNumericCast(info.base_row + i); + } + } else { + auto rows = info.GetRows(); + for (idx_t i = 0; i < info.count; i++) { + row_numbers[count++] = UnsafeNumericCast(info.base_row + rows[i]); + } + } + Flush(version_table, row_numbers, count); +} + +void IndexDataRemover::Verify() { +#ifdef DEBUG + // Verify that our index memory is stable. + for (auto &table : verify_indexes) { + table.second->VerifyIndexBuffers(); + } +#endif +} + +void CommitState::Verify() { + index_data_remover.Verify(); +} + +void IndexDataRemover::Flush(DataTable &table, row_t *row_numbers, idx_t count) { + if (count == 0) { + return; + } +#ifdef DEBUG + verify_indexes.insert(make_pair(reference(table), table.GetDataTableInfo())); +#endif + + // set up the row identifiers vector + Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_numbers)); + + // delete the tuples from all the indexes. + // If there is any issue with removal, a FatalException must be thrown since there may be a corruption of + // data, hence the transaction cannot be guaranteed. + try { + table.RemoveFromIndexes(context, row_identifiers, count, removal_type); + } catch (std::exception &ex) { + throw FatalException(ErrorData(ex).Message()); + } catch (...) 
{ + throw FatalException("unknown failure in CommitState::Flush"); + } + + count = 0; +} + +//===--------------------------------------------------------------------===// +// CommitState +//===--------------------------------------------------------------------===// +CommitState::CommitState(DuckTransaction &transaction_p, transaction_t commit_id, + ActiveTransactionState transaction_state, CommitMode commit_mode) + : transaction(transaction_p), commit_id(commit_id), + index_data_remover(*transaction.context.lock(), GetIndexRemovalType(transaction_state, commit_mode)) { +} + +IndexRemovalType CommitState::GetIndexRemovalType(ActiveTransactionState transaction_state, CommitMode commit_mode) { + if (commit_mode == CommitMode::COMMIT) { + if (transaction_state == ActiveTransactionState::NO_OTHER_TRANSACTIONS) { + // if there are no other active transactions we don't need to store removed rows in deleted_rows_in_use + return IndexRemovalType::MAIN_INDEX_ONLY; + } + return IndexRemovalType::MAIN_INDEX; + } + // revert the appends to the indexes + if (transaction_state == ActiveTransactionState::NO_OTHER_TRANSACTIONS) { + return IndexRemovalType::REVERT_MAIN_INDEX_ONLY; + } + return IndexRemovalType::REVERT_MAIN_INDEX; } void CommitState::CommitEntryDrop(CatalogEntry &entry, data_ptr_t dataptr) { @@ -165,6 +252,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::INSERT_TUPLE: { // append: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->table->CommitAppend(commit_id, info->start_row, info->count); break; @@ -172,13 +265,24 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::DELETE_TUPLE: { // deletion: auto info = reinterpret_cast(data); - // mark the tuples as committed - info->version_info->CommitDelete(info->vector_idx, commit_id, *info); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } + CommitDelete(*info); break; } case UndoFlags::UPDATE_TUPLE: { // update: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } info->version_number = commit_id; break; } @@ -191,6 +295,13 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { } } +void CommitState::CommitDelete(DeleteInfo &info) { + // mark the tuples as committed + info.version_info->CommitDelete(info.vector_idx, commit_id, info); + // delete from indexes + index_data_remover.PushDelete(info); +} + void CommitState::RevertCommit(UndoFlags type, data_ptr_t data) { transaction_t transaction_id = commit_id; switch (type) { @@ -214,7 +325,7 @@ void CommitState::RevertCommit(UndoFlags type, data_ptr_t data) { // deletion: auto info = reinterpret_cast(data); // revert the commit by writing the (uncommitted) transaction_id back into the version info - 
info->version_info->CommitDelete(info->vector_idx, transaction_id, *info); + CommitDelete(*info); break; } case UndoFlags::UPDATE_TUPLE: { diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp index dc6afccb7..fafd62d9e 100644 --- a/src/duckdb/src/transaction/duck_transaction.cpp +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -32,8 +32,8 @@ TransactionData::TransactionData(transaction_t transaction_id_p, transaction_t s DuckTransaction::DuckTransaction(DuckTransactionManager &manager, ClientContext &context_p, transaction_t start_time, transaction_t transaction_id, idx_t catalog_version_p) : Transaction(manager, context_p), start_time(start_time), transaction_id(transaction_id), commit_id(0), - highest_active_query(0), catalog_version(catalog_version_p), awaiting_cleanup(false), - transaction_manager(manager), undo_buffer(*this, context_p), storage(make_uniq(context_p, *this)) { + catalog_version(catalog_version_p), awaiting_cleanup(false), undo_buffer(*this, context_p), + storage(make_uniq(context_p, *this)) { } DuckTransaction::~DuckTransaction() { @@ -51,6 +51,10 @@ DuckTransaction &DuckTransaction::Get(ClientContext &context, Catalog &catalog) return transaction.Cast(); } +DuckTransactionManager &DuckTransaction::GetTransactionManager() { + return manager.Cast(); +} + LocalStorage &DuckTransaction::GetLocalStorage() { return *storage; } @@ -126,11 +130,12 @@ void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_co append_info->count = row_count; } -UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { +UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries, + idx_t row_group_start) { idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto undo_entry = undo_buffer.CreateEntry(UndoFlags::UPDATE_TUPLE, alloc_size); auto &update_info = UpdateInfo::Get(undo_entry); - UpdateInfo::Initialize(update_info, transaction_id); + UpdateInfo::Initialize(update_info, data_table, transaction_id, row_group_start); return undo_entry; } @@ -196,22 +201,28 @@ bool DuckTransaction::ShouldWriteToWAL(AttachedDatabase &db) { return false; } auto &storage_manager = db.GetStorageManager(); - auto log = storage_manager.GetWAL(); - if (!log) { + if (!storage_manager.HasWAL()) { return false; } return true; } -ErrorData DuckTransaction::WriteToWAL(AttachedDatabase &db, unique_ptr &commit_state) noexcept { +ErrorData DuckTransaction::WriteToWAL(ClientContext &context, AttachedDatabase &db, + unique_ptr &commit_state) noexcept { ErrorData error_data; try { D_ASSERT(ShouldWriteToWAL(db)); auto &storage_manager = db.GetStorageManager(); - auto log = storage_manager.GetWAL(); - commit_state = storage_manager.GenStorageCommitState(*log); + auto wal = storage_manager.GetWAL(); + commit_state = storage_manager.GenStorageCommitState(*wal); + + auto &profiler = *context.client_data->profiler; + + auto commit_timer = profiler.StartTimer(MetricType::COMMIT_LOCAL_STORAGE_LATENCY); storage->Commit(commit_state.get()); - undo_buffer.WriteToWAL(*log, commit_state.get()); + + auto wal_timer = profiler.StartTimer(MetricType::WRITE_TO_WAL_LATENCY); + undo_buffer.WriteToWAL(*wal, commit_state.get()); if (commit_state->HasRowGroupData()) { // if we have optimistically written any data AND we are writing to the WAL, we have written references to // optimistically written blocks @@ -235,31 +246,19 @@ ErrorData DuckTransaction::WriteToWAL(AttachedDatabase 
&db, unique_ptr commit_state) noexcept { - // "checkpoint" parameter indicates if the caller will checkpoint. If checkpoint == - // true: Then this function will NOT write to the WAL or flush/persist. - // This method only makes commit in memory, expecting caller to checkpoint/flush. - // false: Then this function WILL write to the WAL and Flush/Persist it. - this->commit_id = new_commit_id; + this->commit_id = commit_info.commit_id; if (!ChangesMade()) { // no need to flush anything if we made no changes return ErrorData(); } - for (auto &entry : modified_tables) { - auto &tbl = entry.first.get(); - if (!tbl.IsMainTable()) { - return ErrorData( - TransactionException("Attempting to modify table %s but another transaction has %s this table", - tbl.GetTableName(), tbl.TableModification())); - } - } D_ASSERT(db.IsSystem() || db.IsTemporary() || !IsReadOnly()); UndoBuffer::IteratorState iterator_state; try { storage->Commit(commit_state.get()); - undo_buffer.Commit(iterator_state, commit_id); + undo_buffer.Commit(iterator_state, commit_info); if (commit_state) { // if we have written to the WAL - flush after the commit has been successful commit_state->FlushCommit(); @@ -289,17 +288,33 @@ void DuckTransaction::Cleanup(transaction_t lowest_active_transaction) { undo_buffer.Cleanup(lowest_active_transaction); } -void DuckTransaction::SetReadWrite() { - Transaction::SetReadWrite(); - // obtain a shared checkpoint lock to prevent concurrent checkpoints while this transaction is running - write_lock = transaction_manager.SharedCheckpointLock(); +void DuckTransaction::SetModifications(DatabaseModificationType type) { + if (write_lock) { + // already have a write lock + return; + } + bool require_write_lock = false; + require_write_lock = require_write_lock || type.InsertDataWithIndex(); + require_write_lock = require_write_lock || type.DeleteData(); + require_write_lock = require_write_lock || type.UpdateData(); + require_write_lock = require_write_lock || type.AlterTable(); + require_write_lock = require_write_lock || type.CreateCatalogEntry(); + require_write_lock = require_write_lock || type.DropCatalogEntry(); + require_write_lock = require_write_lock || type.Sequence(); + require_write_lock = require_write_lock || type.CreateIndex(); + + if (require_write_lock) { + // obtain a shared checkpoint lock to prevent concurrent checkpoints while this transaction is running + write_lock = GetTransactionManager().SharedCheckpointLock(); + } } unique_ptr DuckTransaction::TryGetCheckpointLock() { if (!write_lock) { - throw InternalException("TryUpgradeCheckpointLock - but thread has no shared lock!?"); + return GetTransactionManager().TryGetCheckpointLock(); + } else { + return GetTransactionManager().TryUpgradeCheckpointLock(*write_lock); } - return transaction_manager.TryUpgradeCheckpointLock(*write_lock); } shared_ptr DuckTransaction::SharedLockTable(DataTableInfo &info) { diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp index eace5283c..4b70e4e15 100644 --- a/src/duckdb/src/transaction/duck_transaction_manager.cpp +++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp @@ -1,5 +1,7 @@ #include "duckdb/transaction/duck_transaction_manager.hpp" +#include "duckdb/main/client_data.hpp" + #include "duckdb/catalog/catalog_set.hpp" #include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/common/exception.hpp" @@ -40,6 +42,7 @@ DuckTransactionManager::DuckTransactionManager(AttachedDatabase &db) : 
Transacti current_transaction_id = TRANSACTION_ID_START; lowest_active_id = TRANSACTION_ID_START; lowest_active_start = MAX_TRANSACTION_ID; + active_checkpoint = MAX_TRANSACTION_ID; if (!db.GetCatalog().IsDuckCatalog()) { // Specifically the StorageManager of the DuckCatalog is relied on, with `db.GetStorageManager` throw InternalException("DuckTransactionManager should only be created together with a DuckCatalog"); @@ -87,6 +90,27 @@ Transaction &DuckTransactionManager::StartTransaction(ClientContext &context) { return transaction_ref; } +ActiveCheckpointWrapper::ActiveCheckpointWrapper(DuckTransactionManager &manager) : manager(manager) { +} + +ActiveCheckpointWrapper::~ActiveCheckpointWrapper() { + manager.ResetCheckpointId(); +} + +transaction_t DuckTransactionManager::GetNewCheckpointId() { + if (active_checkpoint != MAX_TRANSACTION_ID) { + throw InternalException( + "DuckTransactionManager::GetNewCheckpointId requested a new id but active_checkpoint was already set"); + } + auto result = last_commit.load(); + active_checkpoint = result; + return result; +} + +void DuckTransactionManager::ResetCheckpointId() { + active_checkpoint = MAX_TRANSACTION_ID; +} + DuckTransactionManager::CheckpointDecision::CheckpointDecision(string reason_p) : can_checkpoint(false), reason(std::move(reason_p)) { } @@ -97,6 +121,15 @@ DuckTransactionManager::CheckpointDecision::CheckpointDecision(CheckpointType ty DuckTransactionManager::CheckpointDecision::~CheckpointDecision() { } +bool DuckTransactionManager::HasOtherTransactions(DuckTransaction &transaction) { + for (auto &active_transaction : active_transactions) { + if (!RefersToSameObject(*active_transaction, transaction)) { + return true; + } + } + return false; +} + DuckTransactionManager::CheckpointDecision DuckTransactionManager::CanCheckpoint(DuckTransaction &transaction, unique_ptr &lock, const UndoBufferProperties &undo_properties) { @@ -123,13 +156,7 @@ DuckTransactionManager::CanCheckpoint(DuckTransaction &transaction, unique_ptr DuckTransactionManager::SharedCheckpointLock() { @@ -227,10 +254,12 @@ unique_ptr DuckTransactionManager::TryUpgradeCheckpointLock(Stor return checkpoint_lock.TryUpgradeCheckpointLock(lock); } +unique_ptr DuckTransactionManager::TryGetCheckpointLock() { + return checkpoint_lock.TryGetExclusiveLock(); +} + transaction_t DuckTransactionManager::GetCommitTimestamp() { - auto commit_ts = current_start_timestamp++; - last_commit = commit_ts; - return commit_ts; + return current_start_timestamp++; } ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Transaction &transaction_p) { @@ -253,22 +282,19 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran unique_ptr> held_wal_lock; unique_ptr commit_state; if (!checkpoint_decision.can_checkpoint && transaction.ShouldWriteToWAL(db)) { + auto &storage_manager = db.GetStorageManager().Cast(); // if we are committing changes and we are not checkpointing, we need to write to the WAL // since WAL writes can take a long time - we grab the WAL lock here and unlock the transaction lock // read-only transactions can bypass this branch and start/commit while the WAL write is happening - if (!transaction.HasWriteLock()) { - // sanity check - this transaction should have a write lock - // the write lock prevents other transactions from checkpointing until this transaction is fully finished - // if we do not hold the write lock here, other transactions can bypass this branch by auto-checkpoint - // this would lead to a checkpoint WHILE 
this thread is writing to the WAL - // this should never happen - throw InternalException("Transaction writing to WAL does not have the write lock"); - } // unlock the transaction lock while we write to the WAL t_lock.unlock(); // grab the WAL lock and hold it until the entire commit is finished - held_wal_lock = make_uniq>(wal_lock); - error = transaction.WriteToWAL(db, commit_state); + held_wal_lock = storage_manager.GetWALLock(); + + // Commit the changes to the WAL. + if (db.GetRecoveryMode() == RecoveryMode::DEFAULT) { + error = transaction.WriteToWAL(context, db, commit_state); + } // after we finish writing to the WAL we grab the transaction lock again t_lock.lock(); @@ -276,18 +302,27 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran // in-memory databases don't have a WAL - we estimate how large their changeset is based on the undo properties if (!db.IsSystem()) { auto &storage_manager = db.GetStorageManager(); - if (storage_manager.InMemory()) { - storage_manager.AddInMemoryChange(undo_properties.estimated_size); + if (storage_manager.InMemory() || db.GetRecoveryMode() == RecoveryMode::NO_WAL_WRITES) { + storage_manager.AddWALSize(undo_properties.estimated_size); } } // obtain a commit id for the transaction - transaction_t commit_id = GetCommitTimestamp(); + CommitInfo info; + info.commit_id = GetCommitTimestamp(); + // commit the UndoBuffer of the transaction if (!error.HasError()) { - error = transaction.Commit(db, commit_id, std::move(commit_state)); + if (HasOtherTransactions(transaction)) { + info.active_transactions = ActiveTransactionState::OTHER_TRANSACTIONS; + } else { + info.active_transactions = ActiveTransactionState::NO_OTHER_TRANSACTIONS; + } + error = transaction.Commit(db, info, std::move(commit_state)); } if (error.HasError()) { + DUCKDB_LOG(context, TransactionLogType, db, "Rollback (after failed commit)", info.commit_id); + // COMMIT not successful: ROLLBACK. checkpoint_decision = CheckpointDecision(error.Message()); transaction.commit_id = 0; @@ -299,6 +334,9 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran error.Message(), rollback_error.Message()); } } else { + DUCKDB_LOG(context, TransactionLogType, db, "Commit", info.commit_id); + last_commit = info.commit_id; + // check if catalog changes were made if (transaction.catalog_version >= TRANSACTION_ID_START) { transaction.catalog_version = ++last_committed_version; @@ -324,8 +362,9 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran } // We do not need to hold the transaction lock during cleanup of transactions, - // as they (1) have been removed, or (2) exited old_transactions. + // as they (1) have been removed, or (2) enter cleanup_info. 
t_lock.unlock(); + held_wal_lock.reset(); { lock_guard c_lock(cleanup_lock); @@ -353,7 +392,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran options.type = checkpoint_decision.type; auto &storage_manager = db.GetStorageManager(); try { - storage_manager.CreateCheckpoint(QueryContext(context), options); + storage_manager.CreateCheckpoint(context, options); } catch (std::exception &ex) { error.Merge(ErrorData(ex)); } @@ -365,6 +404,8 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran void DuckTransactionManager::RollbackTransaction(Transaction &transaction_p) { auto &transaction = transaction_p.Cast(); + DUCKDB_LOG(db.GetDatabase(), TransactionLogType, db, "Rollback", transaction.transaction_id); + ErrorData error; { // Obtain the transaction lock and roll back. @@ -412,7 +453,7 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa idx_t t_index = active_transactions.size(); auto lowest_start_time = TRANSACTION_ID_START; auto lowest_transaction_id = MAX_TRANSACTION_ID; - auto lowest_active_query = MAXIMUM_QUERY_ID; + auto active_checkpoint_id = active_checkpoint.load(); for (idx_t i = 0; i < active_transactions.size(); i++) { if (active_transactions[i].get() == &transaction) { t_index = i; @@ -420,8 +461,9 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa } lowest_start_time = MinValue(lowest_start_time, active_transactions[i]->start_time); lowest_transaction_id = MinValue(lowest_transaction_id, active_transactions[i]->transaction_id); - transaction_t active_query = active_transactions[i]->active_query; - lowest_active_query = MinValue(lowest_active_query, active_query); + } + if (active_checkpoint_id != MAX_TRANSACTION_ID && active_checkpoint_id < lowest_start_time) { + lowest_start_time = active_checkpoint_id; } lowest_active_start = lowest_start_time; lowest_active_id = lowest_transaction_id; @@ -429,7 +471,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa // Decide if we need to store the transaction, or if we can schedule it for cleanup. auto current_transaction = std::move(active_transactions[t_index]); - auto current_query = DatabaseManager::Get(db).ActiveQueryNumber(); if (store_transaction) { // If the transaction made any changes, we need to keep it around. if (transaction.commit_id != 0) { @@ -438,9 +479,7 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.push_back(std::move(current_transaction)); } else { // The transaction was aborted. - // We might still need its information; add it to the set of transactions awaiting GC. - current_transaction->highest_active_query = current_query; - old_transactions.push_back(std::move(current_transaction)); + cleanup_info->transactions.push_back(std::move(current_transaction)); } } else if (transaction.ChangesMade()) { // We do not need to store the transaction, directly schedule it for cleanup. @@ -464,18 +503,8 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa break; } - // Changes made BEFORE this transaction are no longer relevant. - // We can schedule the transaction and its undo buffer for cleanup. recently_committed_transactions[i]->awaiting_cleanup = true; - - // HOWEVER: Any currently running QUERY can still be using - // the version information of the transaction. - // If we remove the UndoBuffer immediately, we have a race condition. - - // Store the current highest active query. 
- recently_committed_transactions[i]->highest_active_query = current_query; - // Move it to the list of transactions awaiting GC. - old_transactions.push_back(std::move(recently_committed_transactions[i])); + cleanup_info->transactions.push_back(std::move(recently_committed_transactions[i])); } if (i > 0) { @@ -485,34 +514,6 @@ unique_ptr DuckTransactionManager::RemoveTransaction(DuckTransa recently_committed_transactions.erase(start, end); } - // Check if we can clean up and free the memory of any old transactions. - i = active_transactions.empty() ? old_transactions.size() : 0; - for (; i < old_transactions.size(); i++) { - D_ASSERT(old_transactions[i]); - D_ASSERT(old_transactions[i]->highest_active_query > 0); - if (old_transactions[i]->highest_active_query >= lowest_active_query) { - // There is still a query running that could be using - // this transactions' data. - break; - } - } - - if (i > 0) { - // We garbage-collected old transactions: - // - Remove them from the list and schedule them for cleanup. - - // We can only safely do the actual memory cleanup when all the - // currently active queries have finished running! (actually, - // when all the currently active scans have finished running...). - - // Because we clean up asynchronously, we only clean up once we - // no longer need the transaction for anything (i.e., we can move it). - for (idx_t t_idx = 0; t_idx < i; t_idx++) { - cleanup_info->transactions.push_back(std::move(old_transactions[t_idx])); - } - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + static_cast(i)); - } - return cleanup_info; } diff --git a/src/duckdb/src/transaction/meta_transaction.cpp b/src/duckdb/src/transaction/meta_transaction.cpp index 6fee5d96b..02ebe8704 100644 --- a/src/duckdb/src/transaction/meta_transaction.cpp +++ b/src/duckdb/src/transaction/meta_transaction.cpp @@ -226,7 +226,7 @@ AttachedDatabase &MetaTransaction::UseDatabase(shared_ptr &dat return db_ref; } -void MetaTransaction::ModifyDatabase(AttachedDatabase &db) { +void MetaTransaction::ModifyDatabase(AttachedDatabase &db, DatabaseModificationType modification) { if (IsReadOnly()) { throw TransactionException("Cannot write to database \"%s\" - transaction is launched in read-only mode", db.GetName()); @@ -235,6 +235,7 @@ void MetaTransaction::ModifyDatabase(AttachedDatabase &db) { if (transaction.IsReadOnly()) { transaction.SetReadWrite(); } + transaction.SetModifications(modification); if (db.IsSystem() || db.IsTemporary()) { // we can always modify the system and temp databases return; diff --git a/src/duckdb/src/transaction/transaction.cpp b/src/duckdb/src/transaction/transaction.cpp index 1f18a6d56..2cc292d83 100644 --- a/src/duckdb/src/transaction/transaction.cpp +++ b/src/duckdb/src/transaction/transaction.cpp @@ -21,4 +21,7 @@ void Transaction::SetReadWrite() { is_read_only = false; } +void Transaction::SetModifications(DatabaseModificationType type) { +} + } // namespace duckdb diff --git a/src/duckdb/src/transaction/transaction_context.cpp b/src/duckdb/src/transaction/transaction_context.cpp index f6958b899..deaae2029 100644 --- a/src/duckdb/src/transaction/transaction_context.cpp +++ b/src/duckdb/src/transaction/transaction_context.cpp @@ -17,6 +17,12 @@ TransactionContext::~TransactionContext() { if (current_transaction) { try { Rollback(nullptr); + } catch (std::exception &ex) { + ErrorData data(ex); + try { + DUCKDB_LOG_ERROR(context, "TransactionContext::~TransactionContext()\t\t" + data.Message()); + } catch (...) 
{ // NOLINT + } } catch (...) { // NOLINT } } diff --git a/src/duckdb/src/transaction/undo_buffer.cpp b/src/duckdb/src/transaction/undo_buffer.cpp index 4408a972b..8adb8e2de 100644 --- a/src/duckdb/src/transaction/undo_buffer.cpp +++ b/src/duckdb/src/transaction/undo_buffer.cpp @@ -15,6 +15,7 @@ #include "duckdb/transaction/delete_info.hpp" #include "duckdb/transaction/rollback_state.hpp" #include "duckdb/transaction/wal_write_state.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { constexpr uint32_t UNDO_ENTRY_HEADER_SIZE = sizeof(UndoFlags) + sizeof(uint32_t); @@ -40,6 +41,7 @@ template void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, T &&callback) { // iterate in insertion order: start with the tail state.current = allocator.tail.get(); + state.started = true; while (state.current) { state.handle = allocator.buffer_manager.Pin(state.current->block); state.start = state.handle.Ptr(); @@ -59,6 +61,9 @@ void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, T &&callback) template void UndoBuffer::IterateEntries(UndoBuffer::IteratorState &state, UndoBuffer::IteratorState &end_state, T &&callback) { + if (!end_state.started) { + return; + } // iterate in insertion order: start with the tail state.current = allocator.tail.get(); while (state.current) { @@ -176,16 +181,9 @@ void UndoBuffer::Cleanup(transaction_t lowest_active_transaction) { // the chunks) // (2) there is no active transaction with start_id < commit_id of this // transaction - CleanupState state(lowest_active_transaction); + CleanupState state(QueryContext(), lowest_active_transaction, active_transaction_state); UndoBuffer::IteratorState iterator_state; IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CleanupEntry(type, data); }); - -#ifdef DEBUG - // Verify that our index memory is stable. 
- for (auto &table : state.indexed_tables) { - table.second->VerifyIndexBuffers(); - } -#endif } void UndoBuffer::WriteToWAL(WriteAheadLog &wal, optional_ptr commit_state) { @@ -194,15 +192,21 @@ void UndoBuffer::WriteToWAL(WriteAheadLog &wal, optional_ptr IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); } -void UndoBuffer::Commit(UndoBuffer::IteratorState &iterator_state, transaction_t commit_id) { - CommitState state(transaction, commit_id); +void UndoBuffer::Commit(UndoBuffer::IteratorState &iterator_state, CommitInfo &info) { + active_transaction_state = info.active_transactions; + + CommitState state(transaction, info.commit_id, active_transaction_state, CommitMode::COMMIT); IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t data) { state.CommitEntry(type, data); }); + + state.Verify(); } void UndoBuffer::RevertCommit(UndoBuffer::IteratorState &end_state, transaction_t transaction_id) { - CommitState state(transaction, transaction_id); + CommitState state(transaction, transaction_id, active_transaction_state, CommitMode::REVERT_COMMIT); UndoBuffer::IteratorState start_state; IterateEntries(start_state, end_state, [&](UndoFlags type, data_ptr_t data) { state.RevertCommit(type, data); }); + + state.Verify(); } void UndoBuffer::Rollback() { diff --git a/src/duckdb/src/transaction/wal_write_state.cpp b/src/duckdb/src/transaction/wal_write_state.cpp index 5fe17e050..c0671eec0 100644 --- a/src/duckdb/src/transaction/wal_write_state.cpp +++ b/src/duckdb/src/transaction/wal_write_state.cpp @@ -27,10 +27,10 @@ WALWriteState::WALWriteState(DuckTransaction &transaction_p, WriteAheadLog &log, : transaction(transaction_p), log(log), commit_state(commit_state), current_table_info(nullptr) { } -void WALWriteState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { - if (current_table_info != table_info) { +void WALWriteState::SwitchTable(DataTableInfo &table_info, UndoFlags new_op) { + if (current_table_info != &table_info) { // write the current table to the log - log.WriteSetTable(table_info->GetSchemaName(), table_info->GetTableName()); + log.WriteSetTable(table_info.GetSchemaName(), table_info.GetTableName()); current_table_info = table_info; } } @@ -171,7 +171,7 @@ void WALWriteState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { void WALWriteState::WriteDelete(DeleteInfo &info) { // switch to the current table, if necessary - SwitchTable(info.table->GetDataTableInfo().get(), UndoFlags::DELETE_TUPLE); + SwitchTable(*info.table->GetDataTableInfo(), UndoFlags::DELETE_TUPLE); if (!delete_chunk) { delete_chunk = make_uniq(); @@ -198,7 +198,7 @@ void WALWriteState::WriteUpdate(UpdateInfo &info) { auto &column_data = info.segment->column_data; auto &table_info = column_data.GetTableInfo(); - SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); + SwitchTable(table_info, UndoFlags::UPDATE_TUPLE); // initialize the update chunk vector update_types; @@ -217,7 +217,7 @@ void WALWriteState::WriteUpdate(UpdateInfo &info) { // write the row ids into the chunk auto row_ids = FlatVector::GetData(update_chunk->data[1]); - idx_t start = column_data.start + info.vector_index * STANDARD_VECTOR_SIZE; + idx_t start = info.row_group_start + info.vector_index * STANDARD_VECTOR_SIZE; auto tuples = info.GetTuples(); for (idx_t i = 0; i < info.N; i++) { row_ids[tuples[i]] = UnsafeNumericCast(start + tuples[i]); diff --git a/src/duckdb/src/verification/deserialized_statement_verifier.cpp 
b/src/duckdb/src/verification/deserialized_statement_verifier.cpp index 1ade815d7..3d72d6159 100644 --- a/src/duckdb/src/verification/deserialized_statement_verifier.cpp +++ b/src/duckdb/src/verification/deserialized_statement_verifier.cpp @@ -13,7 +13,6 @@ DeserializedStatementVerifier::DeserializedStatementVerifier( unique_ptr DeserializedStatementVerifier::Create(const SQLStatement &statement, optional_ptr> parameters) { - auto &select_stmt = statement.Cast(); Allocator allocator; MemoryStream stream(allocator); diff --git a/src/duckdb/src/verification/statement_verifier.cpp b/src/duckdb/src/verification/statement_verifier.cpp index 81f4c4aba..14e4c0491 100644 --- a/src/duckdb/src/verification/statement_verifier.cpp +++ b/src/duckdb/src/verification/statement_verifier.cpp @@ -1,5 +1,9 @@ #include "duckdb/verification/statement_verifier.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + #include "duckdb/common/error_data.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/parser/parser.hpp" @@ -15,13 +19,24 @@ namespace duckdb { +const vector> &StatementVerifier::GetSelectList(QueryNode &node) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: + return node.Cast().select_list; + case QueryNodeType::SET_OPERATION_NODE: + return GetSelectList(*node.Cast().children[0]); + default: + return empty_select_list; + } +} + StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p, optional_ptr> parameters_p) : type(type), name(std::move(name)), statement(std::move(statement_p)), select_statement(statement->type == StatementType::SELECT_STATEMENT ? &statement->Cast() : nullptr), parameters(parameters_p), - select_list(select_statement ? select_statement->node->GetSelectList() : empty_select_list) { + select_list(select_statement ? GetSelectList(*select_statement->node) : empty_select_list) { } StatementVerifier::StatementVerifier(unique_ptr statement_p, diff --git a/src/duckdb/third_party/httplib/httplib.hpp b/src/duckdb/third_party/httplib/httplib.hpp index 4aa0458dc..409c47d0b 100644 --- a/src/duckdb/third_party/httplib/httplib.hpp +++ b/src/duckdb/third_party/httplib/httplib.hpp @@ -7077,7 +7077,12 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { } auto location = res.get_header_value("location"); - if (location.empty()) { return false; } + if (location.empty()) { + // s3 requests will not return a location header, and instead a + // X-Amx-Region-Bucket header. Return true so all response headers + // are returned to the httpfs/calling extension + return true; + } const Regex re( R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); diff --git a/src/duckdb/third_party/libpg_query/include/pg_functions.hpp b/src/duckdb/third_party/libpg_query/include/pg_functions.hpp index bb591f75d..f33723183 100644 --- a/src/duckdb/third_party/libpg_query/include/pg_functions.hpp +++ b/src/duckdb/third_party/libpg_query/include/pg_functions.hpp @@ -3,7 +3,9 @@ #include #include +#ifndef __MVS__ #define fprintf(...) 
+#endif #include "pg_definitions.hpp" diff --git a/src/duckdb/third_party/libpg_query/pg_functions.cpp b/src/duckdb/third_party/libpg_query/pg_functions.cpp index 3b7a7515e..36bed9dcb 100644 --- a/src/duckdb/third_party/libpg_query/pg_functions.cpp +++ b/src/duckdb/third_party/libpg_query/pg_functions.cpp @@ -30,13 +30,8 @@ struct pg_parser_state_str { }; #ifdef __MVS__ -// -------------------------------------------------------- -// Permanent - WIP -// static __tlssim pg_parser_state_impl(); -// #define pg_parser_state (*pg_parser_state_impl.access()) -// -------------------------------------------------------- -// Temporary -static parser_state pg_parser_state; +static __tlssim pg_parser_state_impl; +#define pg_parser_state (*pg_parser_state_impl.access()) #else static __thread parser_state pg_parser_state; #endif diff --git a/src/duckdb/third_party/mbedtls/include/mbedtls_wrapper.hpp b/src/duckdb/third_party/mbedtls/include/mbedtls_wrapper.hpp index d9f8111d8..e6e97a717 100644 --- a/src/duckdb/third_party/mbedtls/include/mbedtls_wrapper.hpp +++ b/src/duckdb/third_party/mbedtls/include/mbedtls_wrapper.hpp @@ -81,6 +81,7 @@ class AESStateMBEDTLS : public duckdb::EncryptionState { DUCKDB_API void GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) override; DUCKDB_API void FinalizeGCM(duckdb::data_ptr_t tag, duckdb::idx_t tag_len); DUCKDB_API const mbedtls_cipher_info_t *GetCipher(size_t key_len); + DUCKDB_API static void SecureClearData(duckdb::data_ptr_t data, duckdb::idx_t len); private: DUCKDB_API void InitializeInternal(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len); @@ -98,6 +99,10 @@ class AESStateMBEDTLS : public duckdb::EncryptionState { } ~AESStateMBEDTLSFactory() override {} // + + DUCKDB_API bool SupportsEncryption() override { + return false; + } }; }; diff --git a/src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp b/src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp index 7dc0af7fd..3a6ce981e 100644 --- a/src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp +++ b/src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp @@ -271,6 +271,10 @@ const mbedtls_cipher_info_t *MbedTlsWrapper::AESStateMBEDTLS::GetCipher(size_t k } } +void MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(duckdb::data_ptr_t data, duckdb::idx_t len) { + mbedtls_platform_zeroize(data, len); +} + MbedTlsWrapper::AESStateMBEDTLS::AESStateMBEDTLS(duckdb::EncryptionTypes::CipherType cipher_p, duckdb::idx_t key_len) : EncryptionState(cipher_p, key_len), context(duckdb::make_uniq()) { mbedtls_cipher_init(context.get()); @@ -296,20 +300,12 @@ MbedTlsWrapper::AESStateMBEDTLS::~AESStateMBEDTLS() { } } -void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(duckdb::data_ptr_t data, duckdb::idx_t len) { - duckdb::RandomEngine random_engine; - - while (len) { - const auto random_integer = random_engine.NextRandomInteger(); - const auto next = duckdb::MinValue(len, sizeof(random_integer)); - memcpy(data, duckdb::const_data_ptr_cast(&random_integer), next); - data += next; - len -= next; - } +static void ThrowInsecureRNG() { + throw duckdb::InvalidConfigurationException("DuckDB requires a secure random engine to be loaded to enable secure crypto. Normally, this will be handled automatically by DuckDB by autoloading the `httpfs` Extension, but that seems to have failed. 
Please ensure the httpfs extension is loaded manually using `LOAD httpfs`."); } void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) { - GenerateRandomDataStatic(data, len); + ThrowInsecureRNG(); } void MbedTlsWrapper::AESStateMBEDTLS::InitializeInternal(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len){ @@ -325,16 +321,7 @@ void MbedTlsWrapper::AESStateMBEDTLS::InitializeInternal(duckdb::const_data_ptr_ } void MbedTlsWrapper::AESStateMBEDTLS::InitializeEncryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t key, duckdb::idx_t key_len_p, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) { - mode = duckdb::EncryptionTypes::ENCRYPT; - - if (key_len_p != key_len) { - throw duckdb::InternalException("Invalid encryption key length, expected %llu, got %llu", key_len, key_len_p); - } - if (mbedtls_cipher_setkey(context.get(), key, key_len * 8, MBEDTLS_ENCRYPT)) { - throw runtime_error("Failed to set AES key for encryption"); - } - - InitializeInternal(iv, iv_len, aad, aad_len); + ThrowInsecureRNG(); } void MbedTlsWrapper::AESStateMBEDTLS::InitializeDecryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t key, duckdb::idx_t key_len_p, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) { diff --git a/src/duckdb/third_party/parquet/parquet_types.cpp b/src/duckdb/third_party/parquet/parquet_types.cpp index 95cfbc3f7..a508a69f2 100644 --- a/src/duckdb/third_party/parquet/parquet_types.cpp +++ b/src/duckdb/third_party/parquet/parquet_types.cpp @@ -1,5 +1,5 @@ /** - * Autogenerated by Thrift Compiler (0.21.0) + * Autogenerated by Thrift Compiler (0.22.0) * * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING * @generated @@ -13,6 +13,14 @@ namespace duckdb_parquet { +template +static typename ENUM::type SafeEnumCast(const std::map &values_to_names, const int &ecast) { + if (values_to_names.find(ecast) == values_to_names.end()) { + throw duckdb_apache::thrift::protocol::TProtocolException(duckdb_apache::thrift::protocol::TProtocolException::INVALID_DATA); + } + return static_cast(ecast); +} + int _kTypeValues[] = { Type::BOOLEAN, Type::INT32, @@ -176,7 +184,14 @@ int _kConvertedTypeValues[] = { * the provided duration. This duration of time is independent of any * particular timezone or date. */ - ConvertedType::INTERVAL + ConvertedType::INTERVAL, + /** + * Non-standard NULL value + * + * This was written by old writers - it is kept here for compatibility purposes. + * See https://github.com/duckdb/duckdb/pull/11774 + */ + ConvertedType::PARQUET_NULL }; const char* _kConvertedTypeNames[] = { /** @@ -300,9 +315,16 @@ const char* _kConvertedTypeNames[] = { * the provided duration. This duration of time is independent of any * particular timezone or date. */ - "INTERVAL" + "INTERVAL", + /** + * Non-standard NULL value + * + * This was written by old writers - it is kept here for compatibility purposes. 
+ * See https://github.com/duckdb/duckdb/pull/11774 + */ + "PARQUET_NULL" }; -const std::map _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(22, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, nullptr, nullptr)); +const std::map _ConvertedType_VALUES_TO_NAMES(::apache::thrift::TEnumIterator(23, _kConvertedTypeValues, _kConvertedTypeNames), ::apache::thrift::TEnumIterator(-1, nullptr, nullptr)); std::ostream& operator<<(std::ostream& out, const ConvertedType::type& val) { std::map::const_iterator it = _ConvertedType_VALUES_TO_NAMES.find(val); @@ -3446,7 +3468,7 @@ GeographyType::~GeographyType() noexcept { GeographyType::GeographyType() noexcept : crs(), - algorithm(static_cast(0)) { + algorithm(SafeEnumCast(_EdgeInterpolationAlgorithm_VALUES_TO_NAMES, 0)) { } void GeographyType::__set_crs(const std::string& val) { @@ -3498,7 +3520,7 @@ uint32_t GeographyType::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast114; xfer += iprot->readI32(ecast114); - this->algorithm = static_cast(ecast114); + this->algorithm = SafeEnumCast(_EdgeInterpolationAlgorithm_VALUES_TO_NAMES, ecast114); this->__isset.algorithm = true; } else { xfer += iprot->skip(ftype); @@ -4067,12 +4089,12 @@ SchemaElement::~SchemaElement() noexcept { } SchemaElement::SchemaElement() noexcept - : type(static_cast(0)), + : type(SafeEnumCast(_Type_VALUES_TO_NAMES, 0)), type_length(0), - repetition_type(static_cast(0)), + repetition_type(SafeEnumCast(_FieldRepetitionType_VALUES_TO_NAMES, 0)), name(), num_children(0), - converted_type(static_cast(0)), + converted_type(SafeEnumCast(_ConvertedType_VALUES_TO_NAMES, 0)), scale(0), precision(0), field_id(0) { @@ -4159,7 +4181,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast123; xfer += iprot->readI32(ecast123); - this->type = static_cast(ecast123); + this->type = SafeEnumCast(_Type_VALUES_TO_NAMES, ecast123); this->__isset.type = true; } else { xfer += iprot->skip(ftype); @@ -4177,7 +4199,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast124; xfer += iprot->readI32(ecast124); - this->repetition_type = static_cast(ecast124); + this->repetition_type = SafeEnumCast(_FieldRepetitionType_VALUES_TO_NAMES, ecast124); this->__isset.repetition_type = true; } else { xfer += iprot->skip(ftype); @@ -4203,7 +4225,7 @@ uint32_t SchemaElement::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast125; xfer += iprot->readI32(ecast125); - this->converted_type = static_cast(ecast125); + this->converted_type = SafeEnumCast(_ConvertedType_VALUES_TO_NAMES, ecast125); this->__isset.converted_type = true; } else { xfer += iprot->skip(ftype); @@ -4405,9 +4427,9 @@ DataPageHeader::~DataPageHeader() noexcept { DataPageHeader::DataPageHeader() noexcept : num_values(0), - encoding(static_cast(0)), - definition_level_encoding(static_cast(0)), - repetition_level_encoding(static_cast(0)) { + encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)), + definition_level_encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)), + repetition_level_encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)) { } void DataPageHeader::__set_num_values(const int32_t val) { @@ -4474,7 +4496,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == 
::apache::thrift::protocol::T_I32) { int32_t ecast130; xfer += iprot->readI32(ecast130); - this->encoding = static_cast(ecast130); + this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast130); isset_encoding = true; } else { xfer += iprot->skip(ftype); @@ -4484,7 +4506,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast131; xfer += iprot->readI32(ecast131); - this->definition_level_encoding = static_cast(ecast131); + this->definition_level_encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast131); isset_definition_level_encoding = true; } else { xfer += iprot->skip(ftype); @@ -4494,7 +4516,7 @@ uint32_t DataPageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast132; xfer += iprot->readI32(ecast132); - this->repetition_level_encoding = static_cast(ecast132); + this->repetition_level_encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast132); isset_repetition_level_encoding = true; } else { xfer += iprot->skip(ftype); @@ -4697,7 +4719,7 @@ DictionaryPageHeader::~DictionaryPageHeader() noexcept { DictionaryPageHeader::DictionaryPageHeader() noexcept : num_values(0), - encoding(static_cast(0)), + encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)), is_sorted(0) { } @@ -4755,7 +4777,7 @@ uint32_t DictionaryPageHeader::read(::apache::thrift::protocol::TProtocol* iprot if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast141; xfer += iprot->readI32(ecast141); - this->encoding = static_cast(ecast141); + this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast141); isset_encoding = true; } else { xfer += iprot->skip(ftype); @@ -4859,7 +4881,7 @@ DataPageHeaderV2::DataPageHeaderV2() noexcept : num_values(0), num_nulls(0), num_rows(0), - encoding(static_cast(0)), + encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)), definition_levels_byte_length(0), repetition_levels_byte_length(0), is_compressed(true) { @@ -4960,7 +4982,7 @@ uint32_t DataPageHeaderV2::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast146; xfer += iprot->readI32(ecast146); - this->encoding = static_cast(ecast146); + this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast146); isset_encoding = true; } else { xfer += iprot->skip(ftype); @@ -5867,7 +5889,7 @@ PageHeader::~PageHeader() noexcept { } PageHeader::PageHeader() noexcept - : type(static_cast(0)), + : type(SafeEnumCast(_PageType_VALUES_TO_NAMES, 0)), uncompressed_page_size(0), compressed_page_size(0), crc(0) { @@ -5944,7 +5966,7 @@ uint32_t PageHeader::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast179; xfer += iprot->readI32(ecast179); - this->type = static_cast(ecast179); + this->type = SafeEnumCast(_PageType_VALUES_TO_NAMES, ecast179); isset_type = true; } else { xfer += iprot->skip(ftype); @@ -6435,8 +6457,8 @@ PageEncodingStats::~PageEncodingStats() noexcept { } PageEncodingStats::PageEncodingStats() noexcept - : page_type(static_cast(0)), - encoding(static_cast(0)), + : page_type(SafeEnumCast(_PageType_VALUES_TO_NAMES, 0)), + encoding(SafeEnumCast(_Encoding_VALUES_TO_NAMES, 0)), count(0) { } @@ -6486,7 +6508,7 @@ uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast192; xfer += iprot->readI32(ecast192); - this->page_type = static_cast(ecast192); + 
this->page_type = SafeEnumCast(_PageType_VALUES_TO_NAMES, ecast192); isset_page_type = true; } else { xfer += iprot->skip(ftype); @@ -6496,7 +6518,7 @@ uint32_t PageEncodingStats::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast193; xfer += iprot->readI32(ecast193); - this->encoding = static_cast(ecast193); + this->encoding = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast193); isset_encoding = true; } else { xfer += iprot->skip(ftype); @@ -6593,8 +6615,8 @@ ColumnMetaData::~ColumnMetaData() noexcept { } ColumnMetaData::ColumnMetaData() noexcept - : type(static_cast(0)), - codec(static_cast(0)), + : type(SafeEnumCast(_Type_VALUES_TO_NAMES, 0)), + codec(SafeEnumCast(_CompressionCodec_VALUES_TO_NAMES, 0)), num_values(0), total_uncompressed_size(0), total_compressed_size(0), @@ -6721,7 +6743,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast198; xfer += iprot->readI32(ecast198); - this->type = static_cast(ecast198); + this->type = SafeEnumCast(_Type_VALUES_TO_NAMES, ecast198); isset_type = true; } else { xfer += iprot->skip(ftype); @@ -6740,7 +6762,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { { int32_t ecast204; xfer += iprot->readI32(ecast204); - this->encodings[_i203] = static_cast(ecast204); + this->encodings[_i203] = SafeEnumCast(_Encoding_VALUES_TO_NAMES, ecast204); } xfer += iprot->readListEnd(); } @@ -6773,7 +6795,7 @@ uint32_t ColumnMetaData::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast210; xfer += iprot->readI32(ecast210); - this->codec = static_cast(ecast210); + this->codec = SafeEnumCast(_CompressionCodec_VALUES_TO_NAMES, ecast210); isset_codec = true; } else { xfer += iprot->skip(ftype); @@ -8651,7 +8673,7 @@ ColumnIndex::~ColumnIndex() noexcept { } ColumnIndex::ColumnIndex() noexcept - : boundary_order(static_cast(0)) { + : boundary_order(SafeEnumCast(_BoundaryOrder_VALUES_TO_NAMES, 0)) { } void ColumnIndex::__set_null_pages(const duckdb::vector & val) { @@ -8780,7 +8802,7 @@ uint32_t ColumnIndex::read(::apache::thrift::protocol::TProtocol* iprot) { if (ftype == ::apache::thrift::protocol::T_I32) { int32_t ecast310; xfer += iprot->readI32(ecast310); - this->boundary_order = static_cast(ecast310); + this->boundary_order = SafeEnumCast(_BoundaryOrder_VALUES_TO_NAMES, ecast310); isset_boundary_order = true; } else { xfer += iprot->skip(ftype); diff --git a/src/duckdb/third_party/parquet/parquet_types.h b/src/duckdb/third_party/parquet/parquet_types.h index a872a3d6b..762d3533a 100644 --- a/src/duckdb/third_party/parquet/parquet_types.h +++ b/src/duckdb/third_party/parquet/parquet_types.h @@ -1,5 +1,5 @@ /** - * Autogenerated by Thrift Compiler (0.21.0) + * Autogenerated by Thrift Compiler (0.22.0) * * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING * @generated @@ -178,7 +178,14 @@ struct ConvertedType { * the provided duration. This duration of time is independent of any * particular timezone or date. */ - INTERVAL = 21 + INTERVAL = 21, + /** + * Non-standard NULL value + * + * This was written by old writers - it is kept here for compatibility purposes. 
+ * See https://github.com/duckdb/duckdb/pull/11774 + */ + PARQUET_NULL = 24 }; }; diff --git a/src/duckdb/third_party/re2/re2/re2.h b/src/duckdb/third_party/re2/re2/re2.h index f34936011..538594a2c 100644 --- a/src/duckdb/third_party/re2/re2/re2.h +++ b/src/duckdb/third_party/re2/re2/re2.h @@ -985,7 +985,7 @@ namespace hooks { // As per https://github.com/google/re2/issues/325, thread_local support in // MinGW seems to be buggy. (FWIW, Abseil folks also avoid it.) #define RE2_HAVE_THREAD_LOCAL -#if (defined(__APPLE__) && !(defined(TARGET_OS_OSX) && TARGET_OS_OSX)) || defined(__MINGW32__) +#if (defined(__APPLE__) && !(defined(TARGET_OS_OSX) && TARGET_OS_OSX)) || defined(__MINGW32__) || defined(__MVS__) #undef RE2_HAVE_THREAD_LOCAL #endif diff --git a/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp b/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp index ea8e28690..8edab2a1a 100644 --- a/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp +++ b/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp @@ -29,6 +29,8 @@ class Utf8Proc { static bool IsValid(const char *s, size_t len); //! Makes Invalid Unicode valid by replacing invalid parts with a given character static void MakeValid(char *s, size_t len, char special_flag = '?'); + //! Creates a new string with invalid UTF-8 characters removed + static std::string RemoveInvalid(const char *s, size_t len); //! Returns the position (in bytes) of the next grapheme cluster static size_t NextGraphemeCluster(const char *s, size_t len, size_t pos); //! Returns the position (in bytes) of the previous grapheme cluster diff --git a/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp b/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp index 1e9fcae89..7ee4f39c2 100644 --- a/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp +++ b/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp @@ -162,6 +162,46 @@ bool Utf8Proc::IsValid(const char *s, size_t len) { return Utf8Proc::Analyze(s, len) != UnicodeType::INVALID; } +std::string Utf8Proc::RemoveInvalid(const char *s, size_t len) { + std::string result; + result.reserve(len); // Reserve the maximum possible size + + for (size_t i = 0; i < len; i++) { + int c = (int)s[i]; + if ((c & 0x80) == 0) { + // ASCII character - always valid + result.push_back(s[i]); + continue; + } + + int first_pos_seq = i; + if ((c & 0xE0) == 0xC0) { + /* 2 byte sequence */ + int utf8char = c & 0x1F; + UTF8ExtraByteLoop<1, 0x000780>(first_pos_seq, utf8char, i, s, len, nullptr, nullptr); + } else if ((c & 0xF0) == 0xE0) { + /* 3 byte sequence */ + int utf8char = c & 0x0F; + UTF8ExtraByteLoop<2, 0x00F800>(first_pos_seq, utf8char, i, s, len, nullptr, nullptr); + } else if ((c & 0xF8) == 0xF0) { + /* 4 byte sequence */ + int utf8char = c & 0x07; + UTF8ExtraByteLoop<3, 0x1F0000>(first_pos_seq, utf8char, i, s, len, nullptr, nullptr); + } else { + // invalid, do not write to output + continue; + } + + // If we get here, the sequence is valid, so add all bytes of the sequence to result + for (size_t j = first_pos_seq; j <= i; j++) { + result.push_back(s[j]); + } + } + + D_ASSERT(Utf8Proc::IsValid(result.c_str(), result.size())); + return result; +} + size_t Utf8Proc::NextGraphemeCluster(const char *s, size_t len, size_t cpos) { int sz; auto prev_codepoint = Utf8Proc::UTF8ToCodepoint(s + cpos, sz); diff --git a/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp new file mode 100644 index 000000000..a848d44a9 --- /dev/null +++ 
b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp
@@ -0,0 +1,33 @@
+//===----------------------------------------------------------------------===//
+// DuckDB
+//
+// yyjson_utils.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "yyjson.hpp"
+
+using namespace duckdb_yyjson; // NOLINT
+
+namespace duckdb {
+
+struct ConvertedJSONHolder {
+public:
+	~ConvertedJSONHolder() {
+		if (doc) {
+			yyjson_mut_doc_free(doc);
+		}
+		if (stringified_json) {
+			free(stringified_json);
+		}
+	}
+
+public:
+	yyjson_mut_doc *doc = nullptr;
+	char *stringified_json = nullptr;
+};
+
+} // namespace duckdb
diff --git a/src/duckdb/ub_extension_core_functions_scalar_bit.cpp b/src/duckdb/ub_extension_core_functions_scalar_bit.cpp
deleted file mode 100644
index 0e48db861..000000000
--- a/src/duckdb/ub_extension_core_functions_scalar_bit.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "extension/core_functions/scalar/bit/bitstring.cpp"
-
diff --git a/src/duckdb/ub_extension_core_functions_scalar_debug.cpp b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp
index f1c3fa82e..d0822fdfe 100644
--- a/src/duckdb/ub_extension_core_functions_scalar_debug.cpp
+++ b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp
@@ -1,2 +1,4 @@
 #include "extension/core_functions/scalar/debug/vector_type.cpp"
 
+#include "extension/core_functions/scalar/debug/sleep.cpp"
+
diff --git a/src/duckdb/ub_extension_core_functions_scalar_enum.cpp b/src/duckdb/ub_extension_core_functions_scalar_enum.cpp
deleted file mode 100644
index 74e9bf3f7..000000000
--- a/src/duckdb/ub_extension_core_functions_scalar_enum.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "extension/core_functions/scalar/enum/enum_functions.cpp"
-
diff --git a/src/duckdb/ub_extension_core_functions_scalar_math.cpp b/src/duckdb/ub_extension_core_functions_scalar_math.cpp
deleted file mode 100644
index 27320ea9f..000000000
--- a/src/duckdb/ub_extension_core_functions_scalar_math.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "extension/core_functions/scalar/math/numeric.cpp"
-
diff --git a/src/duckdb/ub_extension_core_functions_scalar_operators.cpp b/src/duckdb/ub_extension_core_functions_scalar_operators.cpp
deleted file mode 100644
index 47383d4eb..000000000
--- a/src/duckdb/ub_extension_core_functions_scalar_operators.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "extension/core_functions/scalar/operators/bitwise.cpp"
-
diff --git a/src/duckdb/ub_extension_core_functions_scalar_struct.cpp b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp
index 46d3f649f..f7929da2f 100644
--- a/src/duckdb/ub_extension_core_functions_scalar_struct.cpp
+++ b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp
@@ -2,3 +2,7 @@
 
 #include "extension/core_functions/scalar/struct/struct_update.cpp"
 
+#include "extension/core_functions/scalar/struct/struct_keys.cpp"
+
+#include "extension/core_functions/scalar/struct/struct_values.cpp"
+
diff --git a/src/duckdb/ub_extension_parquet_reader_variant.cpp b/src/duckdb/ub_extension_parquet_reader_variant.cpp
index af6cda649..4e55f0b16 100644
--- a/src/duckdb/ub_extension_parquet_reader_variant.cpp
+++ b/src/duckdb/ub_extension_parquet_reader_variant.cpp
@@ -1,6 +1,4 @@
 #include "extension/parquet/reader/variant/variant_binary_decoder.cpp"
 
-#include "extension/parquet/reader/variant/variant_value.cpp"
-
 #include "extension/parquet/reader/variant/variant_shredded_conversion.cpp"
 
diff --git a/src/duckdb/ub_extension_parquet_writer_variant.cpp b/src/duckdb/ub_extension_parquet_writer_variant.cpp
new file mode 100644
index 000000000..6e4563d1e
--- /dev/null
+++ b/src/duckdb/ub_extension_parquet_writer_variant.cpp
@@ -0,0 +1,4 @@
+#include "extension/parquet/writer/variant/convert_variant.cpp"
+
+#include "extension/parquet/writer/variant/analyze_variant.cpp"
+
diff --git a/src/duckdb/ub_src_common.cpp b/src/duckdb/ub_src_common.cpp
index c51f91c48..f4f18e319 100644
--- a/src/duckdb/ub_src_common.cpp
+++ b/src/duckdb/ub_src_common.cpp
@@ -66,6 +66,8 @@
 
 #include "src/common/render_tree.cpp"
 
+#include "src/common/thread_util.cpp"
+
 #include "src/common/tree_renderer.cpp"
 
 #include "src/common/types.cpp"
diff --git a/src/duckdb/ub_src_common_adbc.cpp b/src/duckdb/ub_src_common_adbc.cpp
deleted file mode 100644
index 43aee83e9..000000000
--- a/src/duckdb/ub_src_common_adbc.cpp
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "src/common/adbc/adbc.cpp"
-
-#include "src/common/adbc/driver_manager.cpp"
-
diff --git a/src/duckdb/ub_src_common_crypto.cpp b/src/duckdb/ub_src_common_crypto.cpp
deleted file mode 100644
index 37f51b193..000000000
--- a/src/duckdb/ub_src_common_crypto.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/common/crypto/md5.cpp"
-
diff --git a/src/duckdb/ub_src_common_row_operations.cpp b/src/duckdb/ub_src_common_row_operations.cpp
index f1ac77f8e..f8f47aee8 100644
--- a/src/duckdb/ub_src_common_row_operations.cpp
+++ b/src/duckdb/ub_src_common_row_operations.cpp
@@ -1,16 +1,4 @@
 #include "src/common/row_operations/row_aggregate.cpp"
 
-#include "src/common/row_operations/row_scatter.cpp"
-
-#include "src/common/row_operations/row_gather.cpp"
-
 #include "src/common/row_operations/row_matcher.cpp"
 
-#include "src/common/row_operations/row_external.cpp"
-
-#include "src/common/row_operations/row_radix_scatter.cpp"
-
-#include "src/common/row_operations/row_heap_scatter.cpp"
-
-#include "src/common/row_operations/row_heap_gather.cpp"
-
diff --git a/src/duckdb/ub_src_common_sort.cpp b/src/duckdb/ub_src_common_sort.cpp
index e472e71ff..cac9ac0d6 100644
--- a/src/duckdb/ub_src_common_sort.cpp
+++ b/src/duckdb/ub_src_common_sort.cpp
@@ -1,12 +1,14 @@
-#include "src/common/sort/comparators.cpp"
+#include "src/common/sort/full_sort.cpp"
 
-#include "src/common/sort/merge_sorter.cpp"
+#include "src/common/sort/hashed_sort.cpp"
 
-#include "src/common/sort/partition_state.cpp"
+#include "src/common/sort/natural_sort.cpp"
 
-#include "src/common/sort/radix_sort.cpp"
+#include "src/common/sort/sort.cpp"
 
-#include "src/common/sort/sort_state.cpp"
+#include "src/common/sort/sort_strategy.cpp"
 
-#include "src/common/sort/sorted_block.cpp"
+#include "src/common/sort/sorted_run.cpp"
+
+#include "src/common/sort/sorted_run_merger.cpp"
 
diff --git a/src/duckdb/ub_src_common_sorting.cpp b/src/duckdb/ub_src_common_sorting.cpp
deleted file mode 100644
index b444cb55b..000000000
--- a/src/duckdb/ub_src_common_sorting.cpp
+++ /dev/null
@@ -1,8 +0,0 @@
-#include "src/common/sorting/hashed_sort.cpp"
-
-#include "src/common/sorting/sort.cpp"
-
-#include "src/common/sorting/sorted_run.cpp"
-
-#include "src/common/sorting/sorted_run_merger.cpp"
-
diff --git a/src/duckdb/ub_src_common_tree_renderer.cpp b/src/duckdb/ub_src_common_tree_renderer.cpp
index 65e8dfeba..bf7f6001e 100644
--- a/src/duckdb/ub_src_common_tree_renderer.cpp
+++ b/src/duckdb/ub_src_common_tree_renderer.cpp
@@ -8,5 +8,7 @@
 
 #include "src/common/tree_renderer/yaml_tree_renderer.cpp"
 
+#include "src/common/tree_renderer/mermaid_tree_renderer.cpp"
+
 #include "src/common/tree_renderer/tree_renderer.cpp"
 
diff --git a/src/duckdb/ub_src_common_types.cpp b/src/duckdb/ub_src_common_types.cpp
index 7f181227e..5bcfc4f96 100644
--- a/src/duckdb/ub_src_common_types.cpp
+++ b/src/duckdb/ub_src_common_types.cpp
@@ -54,3 +54,5 @@
 
 #include "src/common/types/vector_constants.cpp"
 
+#include "src/common/types/geometry.cpp"
+
diff --git a/src/duckdb/ub_src_common_types_row.cpp b/src/duckdb/ub_src_common_types_row.cpp
index 3d4ff32c2..b82384bcc 100644
--- a/src/duckdb/ub_src_common_types_row.cpp
+++ b/src/duckdb/ub_src_common_types_row.cpp
@@ -1,13 +1,5 @@
-#include "src/common/types/row/block_iterator.cpp"
-
 #include "src/common/types/row/partitioned_tuple_data.cpp"
 
-#include "src/common/types/row/row_data_collection.cpp"
-
-#include "src/common/types/row/row_data_collection_scanner.cpp"
-
-#include "src/common/types/row/row_layout.cpp"
-
 #include "src/common/types/row/tuple_data_allocator.cpp"
 
 #include "src/common/types/row/tuple_data_collection.cpp"
diff --git a/src/duckdb/ub_src_common_types_variant.cpp b/src/duckdb/ub_src_common_types_variant.cpp
new file mode 100644
index 000000000..f608fdfd1
--- /dev/null
+++ b/src/duckdb/ub_src_common_types_variant.cpp
@@ -0,0 +1,6 @@
+#include "src/common/types/variant/variant.cpp"
+
+#include "src/common/types/variant/variant_value.cpp"
+
+#include "src/common/types/variant/variant_value_convert.cpp"
+
diff --git a/src/duckdb/ub_src_common_value_operations.cpp b/src/duckdb/ub_src_common_value_operations.cpp
deleted file mode 100644
index 429b02ab2..000000000
--- a/src/duckdb/ub_src_common_value_operations.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/common/value_operations/comparison_operations.cpp"
-
diff --git a/src/duckdb/ub_src_execution_operator_csv_scanner_encode.cpp b/src/duckdb/ub_src_execution_operator_csv_scanner_encode.cpp
deleted file mode 100644
index f1db7c41a..000000000
--- a/src/duckdb/ub_src_execution_operator_csv_scanner_encode.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/execution/operator/csv_scanner/encode/csv_encoder.cpp"
-
diff --git a/src/duckdb/ub_src_execution_operator_filter.cpp b/src/duckdb/ub_src_execution_operator_filter.cpp
deleted file mode 100644
index 03631c0cd..000000000
--- a/src/duckdb/ub_src_execution_operator_filter.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/execution/operator/filter/physical_filter.cpp"
-
diff --git a/src/duckdb/ub_src_function_aggregate.cpp b/src/duckdb/ub_src_function_aggregate.cpp
deleted file mode 100644
index 6bfe0ba85..000000000
--- a/src/duckdb/ub_src_function_aggregate.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/aggregate/sorted_aggregate_function.cpp"
-
diff --git a/src/duckdb/ub_src_function_cast.cpp b/src/duckdb/ub_src_function_cast.cpp
index fcf41bbee..99f3378ca 100644
--- a/src/duckdb/ub_src_function_cast.cpp
+++ b/src/duckdb/ub_src_function_cast.cpp
@@ -12,6 +12,8 @@
 
 #include "src/function/cast/enum_casts.cpp"
 
+#include "src/function/cast/geo_casts.cpp"
+
 #include "src/function/cast/list_casts.cpp"
 
 #include "src/function/cast/map_cast.cpp"
diff --git a/src/duckdb/ub_src_function_cast_union.cpp b/src/duckdb/ub_src_function_cast_union.cpp
deleted file mode 100644
index 016863ea2..000000000
--- a/src/duckdb/ub_src_function_cast_union.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/cast/union/from_struct.cpp"
-
diff --git a/src/duckdb/ub_src_function_scalar_date.cpp b/src/duckdb/ub_src_function_scalar_date.cpp
deleted file mode 100644
index 81e2c26c1..000000000
--- a/src/duckdb/ub_src_function_scalar_date.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/scalar/date/strftime.cpp"
-
diff --git a/src/duckdb/ub_src_function_scalar_list.cpp b/src/duckdb/ub_src_function_scalar_list.cpp
index 70726caaf..c29a5bc3d 100644
--- a/src/duckdb/ub_src_function_scalar_list.cpp
+++ b/src/duckdb/ub_src_function_scalar_list.cpp
@@ -2,6 +2,8 @@
 
 #include "src/function/scalar/list/list_extract.cpp"
 
+#include "src/function/scalar/list/list_intersect.cpp"
+
 #include "src/function/scalar/list/list_resize.cpp"
 
 #include "src/function/scalar/list/list_zip.cpp"
diff --git a/src/duckdb/ub_src_function_scalar_map.cpp b/src/duckdb/ub_src_function_scalar_map.cpp
deleted file mode 100644
index 0978d7e0e..000000000
--- a/src/duckdb/ub_src_function_scalar_map.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/scalar/map/map_contains.cpp"
-
diff --git a/src/duckdb/ub_src_function_scalar_sequence.cpp b/src/duckdb/ub_src_function_scalar_sequence.cpp
deleted file mode 100644
index 13127fe1d..000000000
--- a/src/duckdb/ub_src_function_scalar_sequence.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/scalar/sequence/nextval.cpp"
-
diff --git a/src/duckdb/ub_src_function_scalar_variant.cpp b/src/duckdb/ub_src_function_scalar_variant.cpp
index a3276cf42..6fd6a062d 100644
--- a/src/duckdb/ub_src_function_scalar_variant.cpp
+++ b/src/duckdb/ub_src_function_scalar_variant.cpp
@@ -4,3 +4,5 @@
 
 #include "src/function/scalar/variant/variant_typeof.cpp"
 
+#include "src/function/scalar/variant/variant_normalize.cpp"
+
diff --git a/src/duckdb/ub_src_function_table_system.cpp b/src/duckdb/ub_src_function_table_system.cpp
index afa17b21b..5ca818791 100644
--- a/src/duckdb/ub_src_function_table_system.cpp
+++ b/src/duckdb/ub_src_function_table_system.cpp
@@ -1,3 +1,5 @@
+#include "src/function/table/system/duckdb_connection_count.cpp"
+
 #include "src/function/table/system/duckdb_approx_database_count.cpp"
 
 #include "src/function/table/system/duckdb_columns.cpp"
diff --git a/src/duckdb/ub_src_function_table_version.cpp b/src/duckdb/ub_src_function_table_version.cpp
deleted file mode 100644
index 6131c2a6a..000000000
--- a/src/duckdb/ub_src_function_table_version.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/function/table/version/pragma_version.cpp"
-
diff --git a/src/duckdb/ub_src_main.cpp b/src/duckdb/ub_src_main.cpp
index d3709dc92..b8bc22b97 100644
--- a/src/duckdb/ub_src_main.cpp
+++ b/src/duckdb/ub_src_main.cpp
@@ -50,12 +50,16 @@
 
 #include "src/main/profiling_info.cpp"
 
+#include "src/main/profiling_utils.cpp"
+
 #include "src/main/relation.cpp"
 
 #include "src/main/query_profiler.cpp"
 
 #include "src/main/query_result.cpp"
 
+#include "src/main/result_set_manager.cpp"
+
 #include "src/main/stream_query_result.cpp"
 
 #include "src/main/valid_checker.cpp"
diff --git a/src/duckdb/ub_src_main_capi.cpp b/src/duckdb/ub_src_main_capi.cpp
index 30ba6a200..496fc3de6 100644
--- a/src/duckdb/ub_src_main_capi.cpp
+++ b/src/duckdb/ub_src_main_capi.cpp
@@ -6,8 +6,14 @@
 
 #include "src/main/capi/cast_function-c.cpp"
 
+#include "src/main/capi/catalog-c.cpp"
+
 #include "src/main/capi/config-c.cpp"
 
+#include "src/main/capi/config_options-c.cpp"
+
+#include "src/main/capi/copy_function-c.cpp"
+
 #include "src/main/capi/data_chunk-c.cpp"
 
 #include "src/main/capi/datetime-c.cpp"
@@ -26,6 +32,8 @@
 
 #include "src/main/capi/hugeint-c.cpp"
 
+#include "src/main/capi/logging-c.cpp"
+
 #include "src/main/capi/logical_types-c.cpp"
 
 #include "src/main/capi/pending-c.cpp"
diff --git a/src/duckdb/ub_src_main_http.cpp b/src/duckdb/ub_src_main_http.cpp
deleted file mode 100644
index e21e01e04..000000000
--- a/src/duckdb/ub_src_main_http.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/main/http/http_util.cpp"
-
diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp
index f8238dab4..0cbee13d3 100644
--- a/src/duckdb/ub_src_optimizer.cpp
+++ b/src/duckdb/ub_src_optimizer.cpp
@@ -8,6 +8,8 @@
 
 #include "src/optimizer/common_aggregate_optimizer.cpp"
 
+#include "src/optimizer/common_subplan_optimizer.cpp"
+
 #include "src/optimizer/compressed_materialization.cpp"
 
 #include "src/optimizer/cse_optimizer.cpp"
@@ -34,20 +36,28 @@
 
 #include "src/optimizer/late_materialization.cpp"
 
+#include "src/optimizer/late_materialization_helper.cpp"
+
 #include "src/optimizer/optimizer.cpp"
 
+#include "src/optimizer/join_elimination.cpp"
+
 #include "src/optimizer/regex_range_filter.cpp"
 
 #include "src/optimizer/remove_duplicate_groups.cpp"
 
 #include "src/optimizer/remove_unused_columns.cpp"
 
+#include "src/optimizer/row_group_pruner.cpp"
+
 #include "src/optimizer/statistics_propagator.cpp"
 
 #include "src/optimizer/limit_pushdown.cpp"
 
 #include "src/optimizer/topn_optimizer.cpp"
 
+#include "src/optimizer/topn_window_elimination.cpp"
+
 #include "src/optimizer/unnest_rewriter.cpp"
 
 #include "src/optimizer/sampling_pushdown.cpp"
diff --git a/src/duckdb/ub_src_optimizer_matcher.cpp b/src/duckdb/ub_src_optimizer_matcher.cpp
deleted file mode 100644
index 5967c3773..000000000
--- a/src/duckdb/ub_src_optimizer_matcher.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/optimizer/matcher/expression_matcher.cpp"
-
diff --git a/src/duckdb/ub_src_optimizer_rule.cpp b/src/duckdb/ub_src_optimizer_rule.cpp
index 3fa057ede..2a2c56c3c 100644
--- a/src/duckdb/ub_src_optimizer_rule.cpp
+++ b/src/duckdb/ub_src_optimizer_rule.cpp
@@ -36,3 +36,5 @@
 
 #include "src/optimizer/rule/timestamp_comparison.cpp"
 
+#include "src/optimizer/rule/constant_order_normalization.cpp"
+
diff --git a/src/duckdb/ub_src_parallel.cpp b/src/duckdb/ub_src_parallel.cpp
index eee589714..a95258810 100644
--- a/src/duckdb/ub_src_parallel.cpp
+++ b/src/duckdb/ub_src_parallel.cpp
@@ -1,3 +1,5 @@
+#include "src/parallel/async_result.cpp"
+
 #include "src/parallel/base_pipeline_event.cpp"
 
 #include "src/parallel/meta_pipeline.cpp"
diff --git a/src/duckdb/ub_src_parser_query_node.cpp b/src/duckdb/ub_src_parser_query_node.cpp
index f0fefe80e..131571749 100644
--- a/src/duckdb/ub_src_parser_query_node.cpp
+++ b/src/duckdb/ub_src_parser_query_node.cpp
@@ -6,3 +6,5 @@
 
 #include "src/parser/query_node/set_operation_node.cpp"
 
+#include "src/parser/query_node/statement_node.cpp"
+
diff --git a/src/duckdb/ub_src_parser_transform_constraint.cpp b/src/duckdb/ub_src_parser_transform_constraint.cpp
deleted file mode 100644
index 3d74683a3..000000000
--- a/src/duckdb/ub_src_parser_transform_constraint.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-#include "src/parser/transform/constraint/transform_constraint.cpp"
-
diff --git a/src/duckdb/ub_src_planner_binder_query_node.cpp b/src/duckdb/ub_src_planner_binder_query_node.cpp
index 2250c80ca..acecbaf63 100644
--- a/src/duckdb/ub_src_planner_binder_query_node.cpp
+++ b/src/duckdb/ub_src_planner_binder_query_node.cpp
@@ -6,14 +6,12 @@
 
 #include "src/planner/binder/query_node/bind_cte_node.cpp"
 
+#include "src/planner/binder/query_node/bind_statement_node.cpp"
+
 #include "src/planner/binder/query_node/bind_table_macro_node.cpp"
 
 #include "src/planner/binder/query_node/plan_query_node.cpp"
 
-#include "src/planner/binder/query_node/plan_recursive_cte_node.cpp"
-
-#include "src/planner/binder/query_node/plan_cte_node.cpp"
-
 #include "src/planner/binder/query_node/plan_select_node.cpp"
 
 #include "src/planner/binder/query_node/plan_setop.cpp"
diff --git a/src/duckdb/ub_src_planner_binder_tableref.cpp b/src/duckdb/ub_src_planner_binder_tableref.cpp
index b06304d78..641fd88f6 100644
--- a/src/duckdb/ub_src_planner_binder_tableref.cpp
+++ b/src/duckdb/ub_src_planner_binder_tableref.cpp
@@ -22,23 +22,5 @@
 
 #include "src/planner/binder/tableref/bind_named_parameters.cpp"
 
-#include "src/planner/binder/tableref/plan_basetableref.cpp"
-
-#include "src/planner/binder/tableref/plan_delimgetref.cpp"
-
-#include "src/planner/binder/tableref/plan_dummytableref.cpp"
-
-#include "src/planner/binder/tableref/plan_expressionlistref.cpp"
-
-#include "src/planner/binder/tableref/plan_column_data_ref.cpp"
-
 #include "src/planner/binder/tableref/plan_joinref.cpp"
 
-#include "src/planner/binder/tableref/plan_subqueryref.cpp"
-
-#include "src/planner/binder/tableref/plan_table_function.cpp"
-
-#include "src/planner/binder/tableref/plan_cteref.cpp"
-
-#include "src/planner/binder/tableref/plan_pivotref.cpp"
-
diff --git a/src/duckdb/ub_src_planner_expression_binder.cpp b/src/duckdb/ub_src_planner_expression_binder.cpp
index 7f3974ea8..609494bd0 100644
--- a/src/duckdb/ub_src_planner_expression_binder.cpp
+++ b/src/duckdb/ub_src_planner_expression_binder.cpp
@@ -24,6 +24,8 @@
 
 #include "src/planner/expression_binder/order_binder.cpp"
 
+#include "src/planner/expression_binder/try_operator_binder.cpp"
+
 #include "src/planner/expression_binder/projection_binder.cpp"
 
 #include "src/planner/expression_binder/relation_binder.cpp"
diff --git a/src/duckdb/ub_src_planner_filter.cpp b/src/duckdb/ub_src_planner_filter.cpp
index 026be0bcb..683a3b5e5 100644
--- a/src/duckdb/ub_src_planner_filter.cpp
+++ b/src/duckdb/ub_src_planner_filter.cpp
@@ -1,3 +1,5 @@
+#include "src/planner/filter/bloom_filter.cpp"
+
 #include "src/planner/filter/conjunction_filter.cpp"
 
 #include "src/planner/filter/constant_filter.cpp"
@@ -14,3 +16,5 @@
 
 #include "src/planner/filter/optional_filter.cpp"
 
+#include "src/planner/filter/selectivity_optional_filter.cpp"
+
diff --git a/src/duckdb/ub_src_storage.cpp b/src/duckdb/ub_src_storage.cpp
index 43ea1eb7a..78f4b3ec4 100644
--- a/src/duckdb/ub_src_storage.cpp
+++ b/src/duckdb/ub_src_storage.cpp
@@ -4,12 +4,16 @@
 
 #include "src/storage/caching_file_system.cpp"
 
+#include "src/storage/caching_file_system_wrapper.cpp"
+
 #include "src/storage/checkpoint_manager.cpp"
 
 #include "src/storage/temporary_memory_manager.cpp"
 
 #include "src/storage/block.cpp"
 
+#include "src/storage/block_allocator.cpp"
+
 #include "src/storage/data_pointer.cpp"
 
 #include "src/storage/data_table.cpp"
diff --git a/src/duckdb/ub_src_storage_statistics.cpp b/src/duckdb/ub_src_storage_statistics.cpp
index 637a311d7..5d86cee97 100644
--- a/src/duckdb/ub_src_storage_statistics.cpp
+++ b/src/duckdb/ub_src_storage_statistics.cpp
@@ -16,3 +16,7 @@
 
 #include "src/storage/statistics/struct_stats.cpp"
 
+#include "src/storage/statistics/geometry_stats.cpp"
+
+#include "src/storage/statistics/variant_stats.cpp"
+
diff --git a/src/duckdb/ub_src_storage_table.cpp b/src/duckdb/ub_src_storage_table.cpp
index a905148a5..2a3777e94 100644
--- a/src/duckdb/ub_src_storage_table.cpp
+++ b/src/duckdb/ub_src_storage_table.cpp
@@ -24,6 +24,8 @@
 
 #include "src/storage/table/row_group_collection.cpp"
 
+#include "src/storage/table/row_group_reorderer.cpp"
+
 #include "src/storage/table/row_version_manager.cpp"
 
 #include "src/storage/table/scan_state.cpp"
@@ -36,3 +38,5 @@
 
 #include "src/storage/table/validity_column_data.cpp"
 
+#include "src/storage/table/variant_column_data.cpp"
+
diff --git a/src/duckdb/ub_src_storage_table_variant.cpp b/src/duckdb/ub_src_storage_table_variant.cpp
new file mode 100644
index 000000000..2bee2f9c0
--- /dev/null
+++ b/src/duckdb/ub_src_storage_table_variant.cpp
@@ -0,0 +1,4 @@
+#include "src/storage/table/variant/variant_shredding.cpp"
+
+#include "src/storage/table/variant/variant_unshredding.cpp"
+