diff --git a/core/ast/__init__.py b/core/ast/__init__.py
new file mode 100644
index 0000000..9d54b22
--- /dev/null
+++ b/core/ast/__init__.py
@@ -0,0 +1,49 @@
+"""
+AST (Abstract Syntax Tree) module for QueryBooster.
+
+This module provides the node types and classes for representing SQL query structures.
+"""
+
+from .node_type import NodeType
+from .node import (
+    Node,
+    TableNode,
+    SubqueryNode,
+    ColumnNode,
+    LiteralNode,
+    VarNode,
+    VarSetNode,
+    OperatorNode,
+    FunctionNode,
+    SelectNode,
+    FromNode,
+    WhereNode,
+    GroupByNode,
+    HavingNode,
+    OrderByNode,
+    LimitNode,
+    OffsetNode,
+    QueryNode
+)
+
+__all__ = [
+    'NodeType',
+    'Node',
+    'TableNode',
+    'SubqueryNode',
+    'ColumnNode',
+    'LiteralNode',
+    'VarNode',
+    'VarSetNode',
+    'OperatorNode',
+    'FunctionNode',
+    'SelectNode',
+    'FromNode',
+    'WhereNode',
+    'GroupByNode',
+    'HavingNode',
+    'OrderByNode',
+    'LimitNode',
+    'OffsetNode',
+    'QueryNode'
+]
diff --git a/core/ast/node.py b/core/ast/node.py
new file mode 100644
index 0000000..d27584b
--- /dev/null
+++ b/core/ast/node.py
@@ -0,0 +1,179 @@
+from datetime import datetime
+from typing import List, Set, Optional
+from abc import ABC
+
+from .node_type import NodeType
+
+# ============================================================================
+# Base Node Structure
+# ============================================================================
+
+class Node(ABC):
+    """Base class for all AST nodes.
+
+    Attributes:
+        type: NodeType tag identifying what kind of node this is.
+        children: Child nodes, stored in whatever container the caller
+            supplied (a list or a set); a new empty set when omitted.
+    """
+    def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None):
+        self.type = type
+        # Create the default per instance so nodes never share a container.
+        self.children = children if children is not None else set()
+
+
+# ============================================================================
+# Operand Nodes
+# ============================================================================
+
+class TableNode(Node):
+    """Table reference node with a name and an optional alias."""
+    def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
+        super().__init__(NodeType.TABLE, **kwargs)
+        self.name = _name
+        self.alias = _alias
+
+
+# TODO - including query structure arguments (similar to QueryNode) in constructor.
+class SubqueryNode(Node):
+    """Subquery node whose only child is the wrapped query."""
+    def __init__(self, query: 'Node', _alias: Optional[str] = None, **kwargs):
+        super().__init__(NodeType.SUBQUERY, children={query}, **kwargs)
+        self.alias = _alias
+
+
+class ColumnNode(Node):
+    """Column reference node.
+
+    Stores the column name, an optional alias, and optionally the alias of,
+    and/or a reference to, its parent table or subquery.
+    """
+    def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs):
+        super().__init__(NodeType.COLUMN, **kwargs)
+        self.name = _name
+        self.alias = _alias
+        self.parent_alias = _parent_alias
+        self.parent = _parent
+
+
+class LiteralNode(Node):
+    """Literal value node; a value of None is stored as-is."""
+    def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
+        super().__init__(NodeType.LITERAL, **kwargs)
+        self.value = _value
+
+
+class VarNode(Node):
+    """VarSQL variable node"""
+    def __init__(self, _name: str, **kwargs):
+        super().__init__(NodeType.VAR, **kwargs)
+        self.name = _name
+
+
+class VarSetNode(Node):
+    """VarSQL variable set node"""
+    def __init__(self, _name: str, **kwargs):
+        super().__init__(NodeType.VARSET, **kwargs)
+        self.name = _name
+
+
+class OperatorNode(Node):
+    """Operator node.
+
+    Children are kept in a list ([left] or [left, right]) so operand order
+    is preserved; omit _right to build a unary operator.
+    """
+    def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs):
+        # Test against None explicitly: "no right operand" must not depend on
+        # the truthiness of the operand object.
+        children = [_left] if _right is None else [_left, _right]
+        super().__init__(NodeType.OPERATOR, children=children, **kwargs)
+        self.name = _name
+
+
+class FunctionNode(Node):
+    """Function call node; the arguments become the node's children."""
+    def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs):
+        if _args is None:
+            _args = []
+        super().__init__(NodeType.FUNCTION, children=_args, **kwargs)
+        self.name = _name
+
+
+# ============================================================================
+# Query Structure Nodes
+# ============================================================================
+
+class SelectNode(Node):
+    """SELECT clause node"""
+    def __init__(self, _items: Set['Node'], **kwargs):
+        super().__init__(NodeType.SELECT, children=_items, **kwargs)
+
+
+# TODO - confine the valid NodeTypes as children of FromNode
+class FromNode(Node):
+    """FROM clause node"""
+    def __init__(self, _sources: Set['Node'], **kwargs):
+        super().__init__(NodeType.FROM, children=_sources, **kwargs)
+
+
+class WhereNode(Node):
+    """WHERE clause node"""
+    def __init__(self, _predicates: Set['Node'], **kwargs):
+        super().__init__(NodeType.WHERE, children=_predicates, **kwargs)
+
+
+class GroupByNode(Node):
+    """GROUP BY clause node"""
+    def __init__(self, _items: List['Node'], **kwargs):
+        super().__init__(NodeType.GROUP_BY, children=_items, **kwargs)
+
+
+class HavingNode(Node):
+    """HAVING clause node"""
+    def __init__(self, _predicates: Set['Node'], **kwargs):
+        super().__init__(NodeType.HAVING, children=_predicates, **kwargs)
+
+
+class OrderByNode(Node):
+    """ORDER BY clause node"""
+    def __init__(self, _items: List['Node'], **kwargs):
+        super().__init__(NodeType.ORDER_BY, children=_items, **kwargs)
+
+
+class LimitNode(Node):
+    """LIMIT clause node"""
+    def __init__(self, _limit: int, **kwargs):
+        super().__init__(NodeType.LIMIT, **kwargs)
+        self.limit = _limit
+
+
+class OffsetNode(Node):
+    """OFFSET clause node"""
+    def __init__(self, _offset: int, **kwargs):
+        super().__init__(NodeType.OFFSET, **kwargs)
+        self.offset = _offset
+
+
+class QueryNode(Node):
+    """Query root node.
+
+    The supplied clauses become the node's children, as a list in the fixed
+    order select, from, where, group by, having, order by, limit, offset;
+    clauses left as None are skipped.
+    """
+    def __init__(self,
+                 _select: Optional['Node'] = None,
+                 _from: Optional['Node'] = None,
+                 _where: Optional['Node'] = None,
+                 _group_by: Optional['Node'] = None,
+                 _having: Optional['Node'] = None,
+                 _order_by: Optional['Node'] = None,
+                 _limit: Optional['Node'] = None,
+                 _offset: Optional['Node'] = None,
+                 **kwargs):
+        clauses = (_select, _from, _where, _group_by,
+                   _having, _order_by, _limit, _offset)
+        # Compare with None rather than relying on truthiness of the nodes.
+        children = [clause for clause in clauses if clause is not None]
+        super().__init__(NodeType.QUERY, children=children, **kwargs)
diff --git a/core/ast/node_type.py b/core/ast/node_type.py
new file mode 100644
index 0000000..2bae729
--- /dev/null
+++ b/core/ast/node_type.py
@@ -0,0 +1,32 @@
+from enum import Enum
+
+# ============================================================================
+# Node Type Enumeration
+# ============================================================================
+
+class NodeType(Enum):
+    """Node type enumeration"""
+
+    # Operands
+    TABLE = "table"
+    SUBQUERY = "subquery"
+    COLUMN = "column"
+    LITERAL = "literal"
+    # VarSQL specific
+    VAR = "var"
+    VARSET = "varset"
+
+    # Operators
+    OPERATOR = "operator"
+    FUNCTION = "function"
+
+    # Query structure
+    SELECT = "select"
+    FROM = "from"
+    WHERE = "where"
+    GROUP_BY = "group_by"
+    HAVING = "having"
+    ORDER_BY = "order_by"
+    LIMIT = "limit"
+    OFFSET = "offset"
+    QUERY = "query"
diff --git a/tests/test_ast.py b/tests/test_ast.py
new file mode 100644
index 0000000..d6006a4
--- /dev/null
+++ b/tests/test_ast.py
@@ -0,0 +1,392 @@
+from core.ast.node import (
+    TableNode, ColumnNode, LiteralNode, VarNode, VarSetNode,
+    OperatorNode, FunctionNode, SelectNode, FromNode, WhereNode, GroupByNode,
+    HavingNode, OrderByNode, LimitNode, OffsetNode, QueryNode
+)
+from core.ast.node_type import NodeType
+
+
+def test_operand_nodes():
+    """Test all operand node types"""
+    print("="*50)
+    print("Testing Operand Nodes")
+    print("="*50)
+
+    # Test TableNode
+    employees = TableNode("employees", "e")
+    departments = TableNode("departments")
+
+    print("Table nodes:")
+    print(f"  {employees.name} (alias: {employees.alias}) -> Type: {employees.type}")
+    print(f"  {departments.name} (alias: {departments.alias}) -> Type: {departments.type}")
+    assert (employees.name, employees.alias, employees.type) == ("employees", "e", NodeType.TABLE)
+    assert (departments.name, departments.alias) == ("departments", None)
+    assert employees.children == set()
+
+    # Test ColumnNode
+    emp_id = ColumnNode("id", _parent_alias="e")
+    emp_name = ColumnNode("name", "employee_name", "e")
+    dept_name = ColumnNode("name", _parent_alias="d")
+
+    print("\nColumn nodes:")
+    print(f"  {emp_id.name} (parent: {emp_id.parent_alias}) -> Type: {emp_id.type}")
+    print(f"  {emp_name.name} (alias: {emp_name.alias}, parent: {emp_name.parent_alias}) -> Type: {emp_name.type}")
+    print(f"  {dept_name.name} (parent: {dept_name.parent_alias}) -> Type: {dept_name.type}")
+    assert (emp_id.name, emp_id.alias, emp_id.parent_alias) == ("id", None, "e")
+    assert (emp_name.name, emp_name.alias, emp_name.parent_alias) == ("name", "employee_name", "e")
+    assert (dept_name.parent_alias, dept_name.parent, dept_name.type) == ("d", None, NodeType.COLUMN)
+
+    # Test LiteralNode
+    num_literal = LiteralNode(42)
+    str_literal = LiteralNode("John Doe")
+    bool_literal = LiteralNode(True)
+    null_literal = LiteralNode(None)
+
+    print("\nLiteral nodes:")
+    print(f"  {num_literal.value} ({type(num_literal.value).__name__}) -> Type: {num_literal.type}")
+    print(f"  '{str_literal.value}' ({type(str_literal.value).__name__}) -> Type: {str_literal.type}")
+    print(f"  {bool_literal.value} ({type(bool_literal.value).__name__}) -> Type: {bool_literal.type}")
+    print(f"  {null_literal.value} -> Type: {null_literal.type}")
+    assert num_literal.value == 42
+    assert str_literal.value == "John Doe"
+    assert bool_literal.value is True
+    assert null_literal.value is None
+    assert null_literal.type == NodeType.LITERAL
+
+    # Test VarSQL nodes
+    var_table = VarNode("V001")
+    var_column = VarNode("V002")
+    var_set = VarSetNode("VS001")
+
+    print("\nVarSQL nodes:")
+    print(f"  Variable {var_table.name} -> Type: {var_table.type}")
+    print(f"  Variable {var_column.name} -> Type: {var_column.type}")
+    print(f"  VarSet {var_set.name} -> Type: {var_set.type}")
+    assert (var_table.name, var_table.type) == ("V001", NodeType.VAR)
+    assert (var_column.name, var_column.type) == ("V002", NodeType.VAR)
+    assert (var_set.name, var_set.type) == ("VS001", NodeType.VARSET)
+
+
+def test_operator_nodes():
+    """Test operator and function nodes"""
+    print("="*50)
+    print("Testing Operator and Function Nodes")
+    print("="*50)
+
+    # Create some operands for testing
+    age_col = ColumnNode("age")
+    salary_col = ColumnNode("salary")
+    age_limit = LiteralNode(30)
+    salary_limit = LiteralNode(50000)
+    bonus_col = ColumnNode("bonus")
+
+    # Test comparison operators
+    age_gt = OperatorNode(age_col, ">", age_limit)
+    salary_gte = OperatorNode(salary_col, ">=", salary_limit)
+    name_like = OperatorNode(ColumnNode("name"), "LIKE", LiteralNode("%John%"))
+
+    print("Comparison operators:")
+    print(f"  {age_gt.name} operator with {len(age_gt.children)} operands -> Type: {age_gt.type}")
+    print(f"  {salary_gte.name} operator with {len(salary_gte.children)} operands -> Type: {salary_gte.type}")
+    print(f"  {name_like.name} operator with {len(name_like.children)} operands -> Type: {name_like.type}")
+    assert age_gt.children == [age_col, age_limit]
+    assert salary_gte.children == [salary_col, salary_limit]
+    assert (name_like.name, name_like.type) == ("LIKE", NodeType.OPERATOR)
+
+    # Test logical operators
+    and_op = OperatorNode(age_gt, "AND", salary_gte)
+    or_op = OperatorNode(and_op, "OR", name_like)
+    not_op = OperatorNode(age_gt, "NOT")  # Unary operator
+
+    print("\nLogical operators:")
+    print(f"  {and_op.name} operator with {len(and_op.children)} operands -> Type: {and_op.type}")
+    print(f"  {or_op.name} operator with {len(or_op.children)} operands -> Type: {or_op.type}")
+    print(f"  {not_op.name} operator with {len(not_op.children)} operands -> Type: {not_op.type}")
+    assert and_op.children == [age_gt, salary_gte]
+    assert or_op.children == [and_op, name_like]
+    assert not_op.children == [age_gt]
+
+    # Test arithmetic operators
+    add_op = OperatorNode(salary_col, "+", bonus_col)
+    mult_op = OperatorNode(add_op, "*", LiteralNode(1.1))
+    neg_op = OperatorNode(salary_col, "-")  # Unary minus
+
+    print("\nArithmetic operators:")
+    print(f"  {add_op.name} operator with {len(add_op.children)} operands -> Type: {add_op.type}")
+    print(f"  {mult_op.name} operator with {len(mult_op.children)} operands -> Type: {mult_op.type}")
+    print(f"  {neg_op.name} operator with {len(neg_op.children)} operands -> Type: {neg_op.type}")
+    assert add_op.children == [salary_col, bonus_col]
+    assert mult_op.children[0] is add_op
+    assert neg_op.children == [salary_col]
+
+    # Test function nodes (arguments are passed as lists so that their
+    # order is preserved; a set would lose e.g. the CONCAT argument order)
+    first_name = ColumnNode("first_name")
+    separator = LiteralNode(" ")
+    last_name = ColumnNode("last_name")
+    count_func = FunctionNode("COUNT", [ColumnNode("*")])
+    max_func = FunctionNode("MAX", [salary_col])
+    concat_func = FunctionNode("CONCAT", [first_name, separator, last_name])
+    now_func = FunctionNode("NOW")  # No arguments
+
+    print("\nFunction nodes:")
+    print(f"  {count_func.name}() with {len(count_func.children)} args -> Type: {count_func.type}")
+    print(f"  {max_func.name}() with {len(max_func.children)} args -> Type: {max_func.type}")
+    print(f"  {concat_func.name}() with {len(concat_func.children)} args -> Type: {concat_func.type}")
+    print(f"  {now_func.name}() with {len(now_func.children)} args -> Type: {now_func.type}")
+    assert (count_func.name, len(count_func.children)) == ("COUNT", 1)
+    assert max_func.children == [salary_col]
+    assert concat_func.children == [first_name, separator, last_name]
+    assert now_func.children == []
+
+
+def test_query_structure_nodes():
+    """Test query structure nodes"""
+    print("="*50)
+    print("Testing Query Structure Nodes")
+    print("="*50)
+
+    # Create operands
+    emp_table = TableNode("employees", "e")
+    dept_table = TableNode("departments", "d")
+
+    emp_id = ColumnNode("id", _parent_alias="e")
+    emp_name = ColumnNode("name", _parent_alias="e")
+    emp_dept_id = ColumnNode("department_id", _parent_alias="e")
+    dept_id = ColumnNode("id", _parent_alias="d")
+    dept_name = ColumnNode("name", _parent_alias="d")
+
+    # Test SELECT clause
+    select_clause = SelectNode({emp_id, emp_name, dept_name})
+    print(f"SELECT clause with {len(select_clause.children)} items -> Type: {select_clause.type}")
+    assert select_clause.children == {emp_id, emp_name, dept_name}
+    assert select_clause.type == NodeType.SELECT
+
+    # Test FROM clause with JOIN. FromNode only takes the sources, so the
+    # join condition is built and checked on its own here.
+    join_condition = OperatorNode(emp_dept_id, "=", dept_id)
+    from_clause = FromNode({emp_table, dept_table})
+    print(f"FROM clause with {len(from_clause.children)} sources -> Type: {from_clause.type}")
+    assert join_condition.children == [emp_dept_id, dept_id]
+    assert from_clause.children == {emp_table, dept_table}
+
+    # Test WHERE clause
+    age_condition = OperatorNode(ColumnNode("age", _parent_alias="e"), ">", LiteralNode(25))
+    salary_condition = OperatorNode(ColumnNode("salary", _parent_alias="e"), ">=", LiteralNode(40000))
+    combined_condition = OperatorNode(age_condition, "AND", salary_condition)
+    where_clause = WhereNode({combined_condition})
+    print(f"WHERE clause with {len(where_clause.children)} predicates -> Type: {where_clause.type}")
+    assert where_clause.children == {combined_condition}
+
+    # Test GROUP BY clause (a list: the order of the items is kept)
+    group_by_clause = GroupByNode([dept_id, dept_name])
+    print(f"GROUP BY clause with {len(group_by_clause.children)} items -> Type: {group_by_clause.type}")
+    assert group_by_clause.children == [dept_id, dept_name]
+
+    # Test HAVING clause
+    count_condition = OperatorNode(FunctionNode("COUNT", [emp_id]), ">", LiteralNode(5))
+    having_clause = HavingNode({count_condition})
+    print(f"HAVING clause with {len(having_clause.children)} predicates -> Type: {having_clause.type}")
+    assert having_clause.children == {count_condition}
+
+    # Test ORDER BY clause (a list: the order of the items is kept)
+    order_by_clause = OrderByNode([dept_name, emp_name])
+    print(f"ORDER BY clause with {len(order_by_clause.children)} items -> Type: {order_by_clause.type}")
+    assert order_by_clause.children == [dept_name, emp_name]
+
+    # Test LIMIT and OFFSET
+    limit_clause = LimitNode(10)
+    offset_clause = OffsetNode(20)
+    print(f"LIMIT clause: {limit_clause.limit} -> Type: {limit_clause.type}")
+    print(f"OFFSET clause: {offset_clause.offset} -> Type: {offset_clause.type}")
+    assert (limit_clause.limit, limit_clause.type) == (10, NodeType.LIMIT)
+    assert (offset_clause.offset, offset_clause.type) == (20, NodeType.OFFSET)
+
+
+def test_complete_query():
+    """Test building a complete query"""
+    print("="*50)
+    print("Testing Complete Query Construction")
+    print("="*50)
+
+    # Build a complex query:
+    # SELECT e.name, d.name as dept_name, COUNT(*) as emp_count
+    # FROM employees e JOIN departments d ON e.department_id = d.id
+    # WHERE e.salary > 40000 AND e.age < 60
+    # GROUP BY d.id, d.name
+    # HAVING COUNT(*) > 2
+    # ORDER BY dept_name, emp_count DESC
+    # LIMIT 10 OFFSET 5
+
+    # Tables
+    emp_table = TableNode("employees", "e")
+    dept_table = TableNode("departments", "d")
+
+    # Columns
+    emp_name = ColumnNode("name", _parent_alias="e")
+    dept_name = ColumnNode("name", "dept_name", "d")
+    emp_salary = ColumnNode("salary", _parent_alias="e")
+    emp_age = ColumnNode("age", _parent_alias="e")
+    dept_id = ColumnNode("id", _parent_alias="d")
+    count_star = FunctionNode("COUNT", [ColumnNode("*")])
+    count_alias = ColumnNode("emp_count")  # This would be the alias for COUNT(*)
+
+    # SELECT clause
+    select_clause = SelectNode({emp_name, dept_name, count_star})
+
+    # FROM clause (with implicit JOIN logic)
+    from_clause = FromNode({emp_table, dept_table})
+
+    # WHERE clause
+    salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000))
+    age_condition = OperatorNode(emp_age, "<", LiteralNode(60))
+    where_condition = OperatorNode(salary_condition, "AND", age_condition)
+    where_clause = WhereNode({where_condition})
+
+    # GROUP BY clause
+    group_by_clause = GroupByNode([dept_id, dept_name])
+
+    # HAVING clause
+    having_condition = OperatorNode(count_star, ">", LiteralNode(2))
+    having_clause = HavingNode({having_condition})
+
+    # ORDER BY clause
+    order_by_clause = OrderByNode([dept_name, count_alias])
+
+    # LIMIT and OFFSET
+    limit_clause = LimitNode(10)
+    offset_clause = OffsetNode(5)
+
+    # Complete query
+    query = QueryNode(
+        _select=select_clause,
+        _from=from_clause,
+        _where=where_clause,
+        _group_by=group_by_clause,
+        _having=having_clause,
+        _order_by=order_by_clause,
+        _limit=limit_clause,
+        _offset=offset_clause
+    )
+
+    print(f"Complete query built with {len(query.children)} clauses:")
+    print(f"  Query type: {query.type}")
+    print(f"  Total clauses: {len(query.children)}")
+
+    # Analyze query structure
+    clause_types = [child.type for child in query.children]
+    print(f"  Clause types: {[ct.value for ct in clause_types]}")
+
+    # Clauses are stored in the order QueryNode appends them
+    assert query.type == NodeType.QUERY
+    assert query.children == [
+        select_clause, from_clause, where_clause, group_by_clause,
+        having_clause, order_by_clause, limit_clause, offset_clause,
+    ]
+    assert clause_types == [
+        NodeType.SELECT, NodeType.FROM, NodeType.WHERE, NodeType.GROUP_BY,
+        NodeType.HAVING, NodeType.ORDER_BY, NodeType.LIMIT, NodeType.OFFSET,
+    ]
+
+
+def test_varsql_pattern_matching():
+    """Test VarSQL pattern matching capabilities"""
+    print("="*50)
+    print("Testing VarSQL Pattern Matching")
+    print("="*50)
+
+    # Pattern: SELECT V1 FROM V2 WHERE V3 op V4
+    var_select = VarNode("V1")  # Any select item
+    var_table = VarNode("V2")  # Any table
+    var_left = VarNode("V3")  # Left operand of condition
+    var_op = VarNode("OP")  # Any operator
+    var_right = VarNode("V4")  # Right operand of condition
+
+    # Build pattern query
+    pattern_select = SelectNode({var_select})
+    pattern_from = FromNode({var_table})
+    pattern_condition = OperatorNode(var_left, "=", var_right)  # Could use var_op.name
+    pattern_where = WhereNode({pattern_condition})
+
+    pattern_query = QueryNode(
+        _select=pattern_select,
+        _from=pattern_from,
+        _where=pattern_where
+    )
+
+    print("Pattern query created:")
+    print(f"  SELECT variables: {len(pattern_select.children)}")
+    print(f"  FROM variables: {len(pattern_from.children)}")
+    print(f"  WHERE conditions: {len(pattern_where.children)}")
+    print("  Total pattern variables: 4 (V1, V2, V3, V4)")
+    assert pattern_query.children == [pattern_select, pattern_from, pattern_where]
+    assert pattern_condition.children == [var_left, var_right]
+    assert (var_op.name, var_op.type) == ("OP", NodeType.VAR)
+
+    # Test VarSet for multiple columns
+    var_columns = VarSetNode("COLS")
+    multi_select = SelectNode({var_columns})
+    print("\nVarSet pattern for multiple columns:")
+    print(f"  VarSet {var_columns.name} can match multiple SELECT items")
+    assert multi_select.children == {var_columns}
+
+
+def test_node_relationships():
+    """Test node relationships and tree structure"""
+    print("="*50)
+    print("Testing Node Relationships")
+    print("="*50)
+
+    # Build a simple expression tree: (a + b) * c
+    a = ColumnNode("a")
+    b = ColumnNode("b")
+    c = ColumnNode("c")
+
+    add_op = OperatorNode(a, "+", b)
+    mult_op = OperatorNode(add_op, "*", c)
+
+    print("Expression tree: (a + b) * c")
+    print(f"  Root operator: {mult_op.name} ({mult_op.type})")
+    print(f"  Root has {len(mult_op.children)} children")
+
+    # Operator children are stored in a list, in operand order
+    for i, child in enumerate(mult_op.children, start=1):
+        print(f"  Child {i}: {child.type}")
+        if hasattr(child, 'name'):
+            print(f"    Name: {child.name}")
+        if hasattr(child, 'children') and child.children:
+            print(f"    Has {len(child.children)} sub-children")
+    assert mult_op.children == [add_op, c]
+    assert add_op.children == [a, b]
+
+
+if __name__ == '__main__':
+    # Run all test functions
+    import traceback
+
+    test_functions = [
+        test_operand_nodes,
+        test_operator_nodes,
+        test_query_structure_nodes,
+        test_complete_query,
+        test_varsql_pattern_matching,
+        test_node_relationships
+    ]
+
+    failures = 0
+    for test_func in test_functions:
+        try:
+            test_func()
+            print("\n")
+        except Exception as e:
+            failures += 1
+            print(f"ERROR in {test_func.__name__}: {e}")
+            traceback.print_exc()
+            print("\n")
+
+    print("="*50)
+    print("All tests completed!")
+    print("="*50)
+    # Exit non-zero so a failing run is visible to the caller
+    if failures:
+        raise SystemExit(1)