Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
import mo_sql_parsing as mosql
from core.ast.node import QueryNode
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate import: QueryNode is imported on line 2 and again on line 4. Remove the duplicate import on line 2.

Suggested change
from core.ast.node import QueryNode

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, let's remove it.

from core.ast.node import (
QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode,
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports: SubqueryNode, VarNode, and VarSetNode are imported but never used in the formatter implementation. Consider removing unused imports.

Suggested change
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode,
OrderByNode, LimitNode, OffsetNode,

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same comment as in Yihong's PR. We will support Subquery in future PRs when we introduce a complex test case with a subquery.

JoinNode
)
from core.ast.enums import NodeType, JoinType, SortOrder
from core.ast.node import Node

class QueryFormatter:
    """Turn a QueryNode AST into SQL text using mo_sql_parsing."""

    def format(self, query: QueryNode) -> str:
        """Return the SQL string for *query*.

        The AST is first lowered to the JSON dictionary produced by
        ``ast_to_json``; that dictionary is then passed to ``mosql.format``,
        whose result is returned unchanged.
        """
        return mosql.format(ast_to_json(query))

def ast_to_json(node: QueryNode) -> dict:
    """Build the mosql JSON dictionary for a QueryNode.

    Each child clause of *node* is matched on its ``type`` and stored under
    the corresponding key ('select', 'from', 'where', 'groupby', 'having',
    'orderby', 'limit', 'offset'). Children of any other type are ignored.
    If the same clause type occurs more than once, the last one wins.
    """
    # (clause type, output key, builder) in the same order the clauses are tested.
    clause_handlers = (
        (NodeType.SELECT, 'select', format_select),
        (NodeType.FROM, 'from', format_from),
        (NodeType.WHERE, 'where', format_where),
        (NodeType.GROUP_BY, 'groupby', format_group_by),
        (NodeType.HAVING, 'having', format_having),
        (NodeType.ORDER_BY, 'orderby', format_order_by),
        (NodeType.LIMIT, 'limit', lambda clause: clause.limit),
        (NodeType.OFFSET, 'offset', lambda clause: clause.offset),
    )

    json_query = {}
    for clause in node.children:
        for clause_type, key, build in clause_handlers:
            if clause.type == clause_type:
                json_query[key] = build(clause)
                break

    return json_query


def format_select(select_node: SelectNode) -> list:
    """Build the list of SELECT items.

    Every child becomes ``{'value': <expression>}``. Column and function
    children that carry a truthy alias additionally get a ``'name'`` entry
    holding that alias; children of any other type never get a name.
    """

    def _alias_for(item):
        # Columns are read directly; functions are probed because the
        # attribute may be absent. Other node types are never aliased here.
        if item.type == NodeType.COLUMN:
            return item.alias
        if item.type == NodeType.FUNCTION:
            return getattr(item, 'alias', None)
        return None

    entries = []
    for item in select_node.children:
        expression = format_expression(item)
        alias = _alias_for(item)
        if alias:
            entries.append({'name': alias, 'value': expression})
        else:
            entries.append({'value': expression})

    return entries


def format_from(from_node: FromNode) -> list:
    """Build the FROM source list, expanding explicit JOIN nodes.

    Table children contribute one entry each; join children contribute the
    entries returned by ``format_join``. Children of other types are
    skipped. An empty FROM node yields an empty list.
    """
    sources = []

    for source in from_node.children:
        if source.type == NodeType.TABLE:
            sources.append(format_table(source))
        elif source.type == NodeType.JOIN:
            joined = format_join(source)
            # format_join returns a list; a non-list result is still kept
            # as a single entry.
            if isinstance(joined, list):
                sources.extend(joined)
            else:
                sources.append(joined)

    return sources


def format_join(join_node: JoinNode) -> list:
    """Flatten a JOIN node into a list of mosql FROM entries.

    The first child is the left side (a table or a nested join), the second
    is the right-hand table and an optional third child is the ON condition.
    The result is the formatted left side followed by one dictionary keyed
    by the join keyword (plus ``'on'`` when a condition is present).

    Raises:
        ValueError: if the node has fewer than two children.
    """
    operands = list(join_node.children)
    if len(operands) < 2:
        raise ValueError("JoinNode must have at least 2 children (left and right tables)")

    left, right, *extra = operands
    condition = extra[0] if extra else None

    # Left side: recurse into nested joins, wrap plain tables, and emit
    # nothing for any other node type.
    if left.type == NodeType.JOIN:
        entries = list(format_join(left))
    elif left.type == NodeType.TABLE:
        entries = [format_table(left)]
    else:
        entries = []

    # Join types without an entry fall back to a plain 'join'.
    keyword = {
        JoinType.INNER: 'join',
        JoinType.LEFT: 'left join',
        JoinType.RIGHT: 'right join',
        JoinType.FULL: 'full join',
        JoinType.CROSS: 'cross join',
    }.get(join_node.join_type, 'join')

    clause = {keyword: format_table(right)}
    if condition:
        clause['on'] = format_expression(condition)
    entries.append(clause)

    return entries


def format_table(table_node: TableNode) -> dict:
    """Return the mosql entry for a table reference.

    The table name is stored under ``'value'``; a truthy alias is stored
    under ``'name'``.
    """
    entry = {'value': table_node.name}
    alias = table_node.alias
    if not alias:
        return entry
    entry['name'] = alias
    return entry


def format_where(where_node: WhereNode) -> dict:
    """Build the WHERE condition.

    A single predicate is returned as its formatted expression; any other
    number of predicates is wrapped as ``{'and': [...]}``.
    """
    conditions = [format_expression(predicate) for predicate in where_node.children]
    if len(conditions) == 1:
        return conditions[0]
    return {'and': conditions}


def format_group_by(group_by_node: GroupByNode) -> list:
    """Build the GROUP BY list: one ``{'value': <expression>}`` per child."""
    grouping = []
    for key_node in group_by_node.children:
        grouping.append({'value': format_expression(key_node)})
    return grouping


def format_having(having_node: HavingNode) -> dict:
    """Build the HAVING condition.

    With exactly one predicate its formatted expression is returned as-is;
    otherwise all formatted predicates are combined under an ``'and'`` key.
    """
    formatted = [format_expression(predicate) for predicate in having_node.children]
    return formatted[0] if len(formatted) == 1 else {'and': formatted}


def format_order_by(order_by_node: OrderByNode) -> list:
    """Build the ORDER BY item list.

    Each child is either an ORDER_BY_ITEM wrapper (whose first child is the
    sort expression and whose ``sort`` attribute is the direction) or a bare
    expression, which is treated as ascending. An expression with a truthy
    ``alias`` is referenced by that alias; otherwise it is formatted with
    ``format_expression``.

    The direction is emitted per item: every non-ascending item gets its own
    ``'sort'`` entry. SQL applies ASC/DESC to each sort expression
    individually, so ``ORDER BY a DESC, b DESC`` must not be collapsed to
    ``ORDER BY a, b DESC`` (that would sort ``a`` ascending).

    Raises:
        ValueError: if an ORDER_BY_ITEM node has no child expression.
    """
    result = []

    for child in order_by_node.children:
        if child.type == NodeType.ORDER_BY_ITEM:
            operands = list(child.children)
            if not operands:
                raise ValueError("OrderByItemNode must have a child expression to sort by")
            target = operands[0]
            sort_order = child.sort
        else:
            # Direct column reference (no OrderByItemNode wrapper)
            target = child
            sort_order = SortOrder.ASC

        alias = getattr(target, 'alias', None)
        item = {'value': alias if alias else format_expression(target)}

        # Ascending is the SQL default, so only other directions are spelled
        # out. A missing direction (None) is treated like the default.
        if sort_order is not None and sort_order != SortOrder.ASC:
            item['sort'] = sort_order.value.lower()

        result.append(item)

    return result


def format_expression(node: Node):
    """Convert an expression node into its mosql JSON representation.

    * COLUMN   -> ``"alias.name"`` when ``parent_alias`` is set, else ``"name"``
    * LITERAL  -> the node's ``value``
    * FUNCTION -> ``{name: arg}`` for one argument, ``{name: [args]}`` for
      several, ``{name: {}}`` for none
    * OPERATOR -> ``{op: operand}`` for one operand, ``{op: [operands]}``
      for two or more
    * TABLE    -> the result of ``format_table``

    Raises:
        ValueError: for an operator without operands or an unsupported
            node type.
    """
    if node.type == NodeType.COLUMN:
        if node.parent_alias:
            return f"{node.parent_alias}.{node.name}"
        return node.name

    if node.type == NodeType.LITERAL:
        # NOTE(review): mo_sql_parsing represents string literals as
        # {'literal': '...'}; a bare str is presumably rendered as an
        # identifier. Confirm how LiteralNode stores strings before relying
        # on this for text values.
        return node.value

    if node.type == NodeType.FUNCTION:
        # format: {'function_name': args}
        func_name = node.name.lower()
        args = [format_expression(arg) for arg in node.children]
        if not args:
            # Zero-argument call (e.g. NOW()): an empty dict is the shape
            # the mo_sql_parsing parser itself produces for it.
            return {func_name: {}}
        return {func_name: args[0] if len(args) == 1 else args}

    if node.type == NodeType.OPERATOR:
        # format: {'operator': [operands]}
        op_map = {
            '>': 'gt',
            '<': 'lt',
            '>=': 'gte',
            '<=': 'lte',
            '=': 'eq',
            # mo_sql_parsing names inequality 'neq'; any other key would be
            # rendered as an ordinary function call.
            '!=': 'neq',
            '<>': 'neq',
            'AND': 'and',
            'OR': 'or',
        }

        op_name = op_map.get(node.name.upper(), node.name.lower())
        operands = [format_expression(operand) for operand in node.children]

        if not operands:
            raise ValueError(f"Operator node '{node.name}' has no operands")
        if len(operands) == 1:
            # unary operator
            return {op_name: operands[0]}
        # Binary or n-ary: keep every operand rather than only the first.
        return {op_name: operands}

    if node.type == NodeType.TABLE:
        return format_table(node)

    raise ValueError(f"Unsupported node type in expression: {node.type}")
89 changes: 89 additions & 0 deletions tests/test_query_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import mo_sql_parsing as mosql
from core.query_formatter import QueryFormatter
from core.ast.node import (
OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused imports: Several node types are imported but never used in this test file: SubqueryNode, VarNode, VarSetNode. Consider removing unused imports.

Suggested change
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
OrderByNode, LimitNode, OffsetNode, JoinNode

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as the above. We will use them once we have a test case with a subquery.

)
from core.ast.enums import NodeType, JoinType, SortOrder
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'NodeType' is not used.

Suggested change
from core.ast.enums import NodeType, JoinType, SortOrder
from core.ast.enums import JoinType, SortOrder

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Let's remove it.

from data.queries import get_query
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused import: get_query is imported but never used in this test file. Consider removing it.

Suggested change
from data.queries import get_query

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Let's remove it.

from re import sub

# Module-level formatter instance used by the tests in this file.
formatter = QueryFormatter()

def normalize_sql(s):
    """Return *s* trimmed, with every whitespace run collapsed to one space.

    Used to compare SQL strings without regard to layout.
    """
    trimmed = s.strip()
    return sub(r'\s+', ' ', trimmed)

def test_basic_format():
    """Format a hand-built AST and compare it with the expected SQL text.

    The query covers SELECT with aliases, an inner JOIN, WHERE, GROUP BY,
    HAVING, ORDER BY with mixed directions, LIMIT and OFFSET. Comparison is
    whitespace-insensitive (see ``normalize_sql``).
    """
    # --- tables ---
    employees = TableNode("employees", "e")
    departments = TableNode("departments", "d")

    # --- columns ---
    e_name = ColumnNode("name", _parent_alias="e")
    e_salary = ColumnNode("salary", _parent_alias="e")
    e_age = ColumnNode("age", _parent_alias="e")
    e_department_id = ColumnNode("department_id", _parent_alias="e")

    d_name = ColumnNode("name", _alias="dept_name", _parent_alias="d")
    d_id = ColumnNode("id", _parent_alias="d")

    employee_count = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")])

    # --- clauses ---
    select_clause = SelectNode([e_name, d_name, employee_count])

    on_condition = OperatorNode(e_department_id, "=", d_id)
    from_clause = FromNode([JoinNode(employees, departments, JoinType.INNER, on_condition)])

    salary_filter = OperatorNode(e_salary, ">", LiteralNode(40000))
    age_filter = OperatorNode(e_age, "<", LiteralNode(60))
    where_clause = WhereNode([OperatorNode(salary_filter, "AND", age_filter)])

    group_by_clause = GroupByNode([d_id, d_name])

    having_clause = HavingNode([OperatorNode(employee_count, ">", LiteralNode(2))])

    order_by_clause = OrderByNode([
        OrderByItemNode(d_name, SortOrder.ASC),
        OrderByItemNode(employee_count, SortOrder.DESC),
    ])

    limit_clause = LimitNode(10)
    offset_clause = OffsetNode(5)

    # --- complete query ---
    query_ast = QueryNode(
        _select=select_clause,
        _from=from_clause,
        _where=where_clause,
        _group_by=group_by_clause,
        _having=having_clause,
        _order_by=order_by_clause,
        _limit=limit_clause,
        _offset=offset_clause
    )

    # Expected query text (layout is irrelevant; it is normalized below).
    expected_sql = """
SELECT e.name, d.name AS dept_name, COUNT(*) AS emp_count
FROM employees AS e JOIN departments AS d ON e.department_id = d.id
WHERE e.salary > 40000 AND e.age < 60
GROUP BY d.id, d.name
HAVING COUNT(*) > 2
ORDER BY dept_name, emp_count DESC
LIMIT 10 OFFSET 5
""".strip()

    # Diagnostic output shown when the test fails.
    print(mosql.parse(expected_sql))
    print(query_ast)

    actual_sql = formatter.format(query_ast).strip()

    assert normalize_sql(actual_sql) == normalize_sql(expected_sql)
Comment on lines +21 to +89
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Limited test coverage: Consider adding test cases for edge cases such as:

  • SELECT with no FROM clause
  • WHERE with OR conditions
  • Different join types (LEFT, RIGHT, FULL)
  • Nested subqueries
  • Functions with multiple arguments
  • NULL comparisons (IS NULL, IS NOT NULL)

While the current test provides good basic coverage, these cases would ensure the formatter handles various SQL patterns correctly.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. We will add them in the following PRs.

Comment on lines +86 to +89
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Test could be more comprehensive: The test only validates the final SQL output but doesn't verify the intermediate JSON format produced by ast_to_json(). Consider adding an assertion to check the JSON structure matches mosql's expected format, which would help catch issues in the AST-to-JSON conversion step separately from the JSON-to-SQL step.

Copilot uses AI. Check for mistakes.