diff --git a/src/optimizer/join_order/relation_manager.cpp b/src/optimizer/join_order/relation_manager.cpp index e67a76c04752..2a6f673fbfb8 100644 --- a/src/optimizer/join_order/relation_manager.cpp +++ b/src/optimizer/join_order/relation_manager.cpp @@ -1,6 +1,7 @@ #include "duckdb/optimizer/join_order/relation_manager.hpp" #include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/enums/logical_operator_type.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/optimizer/join_order/join_order_optimizer.hpp" #include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" @@ -46,30 +47,49 @@ void RelationManager::AddAggregateOrWindowRelation(LogicalOperator &op, optional void RelationManager::AddRelation(LogicalOperator &op, optional_ptr parent, const RelationStats &stats) { - // if parent is null, then this is a root relation - // if parent is not null, it should have multiple children - D_ASSERT(!parent || parent->children.size() >= 2); + // // if parent is null, then this is a root relation + // // if parent is not null, it should have multiple children + // D_ASSERT(!parent || parent->children.size() >= 2); + // auto relation = make_uniq(op, parent, stats); + // auto relation_id = relations.size(); + // + // auto table_indexes = op.GetTableIndex(); + // if (table_indexes.empty()) { + // // relation represents a non-reorderable relation, most likely a join relation + // // Get the tables referenced in the non-reorderable relation and add them to the relation mapping + // // This should return all table references, even if there are nested non-reorderable joins. + // unordered_set table_references; + // LogicalJoin::GetTableReferences(op, table_references); + // D_ASSERT(table_references.size() > 0); + // for (auto &reference : table_references) { + // D_ASSERT(relation_mapping.find(reference) == relation_mapping.end()); + // relation_mapping[reference] = relation_id; + // } + // } else if (op.type == LogicalOperatorType::LOGICAL_UNNEST) { + // // logical unnest has a logical_unnest index, but other bindings can refer to + // // columns that are not unnested. + // auto bindings = op.GetColumnBindings(); + // for (auto &binding : bindings) { + // relation_mapping[binding.table_index] = relation_id; + // } + // } else { + // // Relations should never return more than 1 table index + // D_ASSERT(table_indexes.size() == 1); + // idx_t table_index = table_indexes.at(0); + // D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end()); + // relation_mapping[table_index] = relation_id; + // } + // relations.push_back(std::move(relation)); + // op.estimated_cardinality = stats.cardinality; + // op.has_estimated_cardinality = true; auto relation = make_uniq(op, parent, stats); auto relation_id = relations.size(); - auto table_indexes = op.GetTableIndex(); - if (table_indexes.empty()) { - // relation represents a non-reorderable relation, most likely a join relation - // Get the tables referenced in the non-reorderable relation and add them to the relation mapping - // This should all table references, even if there are nested non-reorderable joins. - unordered_set table_references; - LogicalJoin::GetTableReferences(op, table_references); - D_ASSERT(table_references.size() > 0); - for (auto &reference : table_references) { - D_ASSERT(relation_mapping.find(reference) == relation_mapping.end()); - relation_mapping[reference] = relation_id; + auto op_bindings = op.GetColumnBindings(); + for (auto &binding : op_bindings) { + if (relation_mapping.find(binding.table_index) == relation_mapping.end()) { + relation_mapping[binding.table_index] = relation_id; } - } else { - // Relations should never return more than 1 table index - D_ASSERT(table_indexes.size() == 1); - idx_t table_index = table_indexes.at(0); - D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end()); - relation_mapping[table_index] = relation_id; } relations.push_back(std::move(relation)); op.estimated_cardinality = stats.cardinality; @@ -319,8 +339,8 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica return can_reorder_left && can_reorder_right; } case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: { - bool can_reorder_right = ExtractJoinRelations(optimizer, *op->children[1], filter_operators, op); bool can_reorder_left = ExtractJoinRelations(optimizer, *op->children[0], filter_operators, op); + bool can_reorder_right = ExtractJoinRelations(optimizer, *op->children[1], filter_operators, op); return can_reorder_left && can_reorder_right; } case LogicalOperatorType::LOGICAL_DUMMY_SCAN: { diff --git a/test/optimizer/join_reorder_optimizer.test b/test/optimizer/join_reorder_optimizer.test index 40fbbff0ccc4..8324115e7c9a 100644 --- a/test/optimizer/join_reorder_optimizer.test +++ b/test/optimizer/join_reorder_optimizer.test @@ -2,6 +2,9 @@ # description: Make sure we can emit a vaild join order by DPhyp if hypergraph is connected # group: [optimizer] +statement ok +pragma enable_verification + statement ok CREATE TABLE t1(c1 int, c2 int, c3 int, c4 int) @@ -33,35 +36,80 @@ statement ok PRAGMA debug_force_no_cross_product=true statement ok -EXPLAIN -SELECT - COUNT(*) -FROM - t1, t2, t3, t4 -WHERE - t1.c1 = t2.c1 AND - t2.c2 = t3.c2 AND +EXPLAIN +SELECT + COUNT(*) +FROM + t1, t2, t3, t4 +WHERE + t1.c1 = t2.c1 AND + t2.c2 = t3.c2 AND t3.c3 = t4.c3 statement ok -EXPLAIN -SELECT - COUNT(*) -FROM - t1, t2, t3, t4 -WHERE - t1.c1 = t2.c1 AND - t2.c2 = t3.c2 AND - t3.c3 = t4.c3 AND +EXPLAIN +SELECT + COUNT(*) +FROM + t1, t2, t3, t4 +WHERE + t1.c1 = t2.c1 AND + t2.c2 = t3.c2 AND + t3.c3 = t4.c3 AND t4.c4 = t1.c4 statement ok -EXPLAIN -SELECT - COUNT(*) -FROM - t1, t2, t3, t4 -WHERE - t1.c1 = t2.c1 AND - t2.c2 = t3.c2 AND +EXPLAIN +SELECT + COUNT(*) +FROM + t1, t2, t3, t4 +WHERE + t1.c1 = t2.c1 AND + t2.c2 = t3.c2 AND t1.c1 + t2.c2 + t3.c3= 3 * t4.c4 + +statement ok +PRAGMA debug_force_no_cross_product=false + +statement ok +with +grid as ( + from (values ('ABC'), ('DEF')) as v(data) + select + unnest(split(data, '')) as letter, + row_number() over () as row_id, + generate_subscripts(split(data, ''), 1) AS col_id, +), +search(row_i, col_i, letter_to_match) as ( + values (0, 0, 'A'), (0, 1, 'B'), +) +from (from grid cross join search) as grid_searches +select exists( + from grid as grid_to_search + where 1=1 + and grid_searches.row_id = grid_to_search.row_id + grid_searches.row_i + and grid_searches.col_id = grid_to_search.col_id + grid_searches.col_i + and grid_searches.letter_to_match = grid_to_search.letter +) + +statement ok +with +grid as ( + from (values ('ABC', 39), ('DEF', 50)) as v(data, row_id) + select + unnest(split(data, '')) as letter, + row_id, + generate_subscripts(split(data, ''), 1) AS col_id, +), +search(row_i, col_i, letter_to_match) as ( + values (0, 0, 'A'), (0, 1, 'B'), +) +from (from grid cross join search) as grid_searches +select exists( + from grid as grid_to_search + where 1=1 + and grid_searches.row_id = grid_to_search.row_id + grid_searches.row_i + and grid_searches.col_id = grid_to_search.col_id + grid_searches.col_i + and grid_searches.letter_to_match = grid_to_search.letter +)