From ef479ca106fd34285a601270891109fd2e4b2917 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Wed, 1 Oct 2025 23:28:18 -0400 Subject: [PATCH 1/3] add basic structure --- core/qb_structure.py | 399 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 399 insertions(+) create mode 100644 core/qb_structure.py diff --git a/core/qb_structure.py b/core/qb_structure.py new file mode 100644 index 0000000..63c1403 --- /dev/null +++ b/core/qb_structure.py @@ -0,0 +1,399 @@ +from dataclasses import dataclass, field +from typing import Any, List, Optional +from enum import Enum +from abc import ABC + +# ============================================================================ +# Base Node Structure +# ============================================================================ + +class NodeKind(Enum): + """Node type enumeration""" + # Query structure + QUERY = "query" + SELECT = "select" + FROM = "from" + WHERE = "where" + GROUP_BY = "group_by" + HAVING = "having" + ORDER_BY = "order_by" + LIMIT = "limit" + OFFSET = "offset" + + # Tables and columns + TABLE = "table" + COLUMN = "column" + SUBQUERY = "subquery" + + # Expressions + LITERAL = "literal" + COMPARISON = "comparison" + LOGICAL = "logical" + ARITHMETIC = "arithmetic" + CASE = "case" + CAST = "cast" + + # Pattern matching + VAR = "var" + VARLIST = "varlist" + + # Other + LIST = "list" + TUPLE = "tuple" + + +@dataclass +class Node(ABC): + """Base class for all nodes""" + type: NodeKind + name: str = "" + alias: Optional[str] = None + value: Any = None + children: List['Node'] = field(default_factory=list) + + +# ============================================================================ +# Specific Node Types +# ============================================================================ + +@dataclass +class TableNode(Node): + """Table reference node""" + type: NodeKind = NodeKind.TABLE + + def __init__(self, name: str, alias: Optional[str] = None, **kwargs): + super().__init__( + type=NodeKind.TABLE, + name=name, + alias=alias, + **kwargs + ) + + def get_identity(self) -> str: + """Get table identity (alias-insensitive)""" + return self.name.lower() + + +@dataclass +class ColumnNode(Node): + """Column reference node""" + type: NodeKind = NodeKind.COLUMN + + def __init__(self, name: str, table_alias: Optional[str] = None, **kwargs): + super().__init__( + type=NodeKind.COLUMN, + name=name, + alias=table_alias, + **kwargs + ) + + def get_identity(self) -> str: + """Get column identity (qualifier-insensitive)""" + return self.name.lower() + + +@dataclass +class ComparisonNode(Node): + """Comparison operation node""" + type: NodeKind = NodeKind.COMPARISON + + def __init__(self, left: 'Node', op: str, right: 'Node', **kwargs): + super().__init__( + type=NodeKind.COMPARISON, + name=op, + children=[left, right], + **kwargs + ) + + +@dataclass +class LogicalNode(Node): + """Logical operation node (AND/OR)""" + type: NodeKind = NodeKind.LOGICAL + + def __init__(self, op: str, terms: List['Node'], **kwargs): + super().__init__( + type=NodeKind.LOGICAL, + name=op, + children=terms, + **kwargs + ) + + +@dataclass +class ArithmeticNode(Node): + """Arithmetic operation node""" + type: NodeKind = NodeKind.ARITHMETIC + + def __init__(self, left: 'Node', op: str, right: 'Node', **kwargs): + super().__init__( + type=NodeKind.ARITHMETIC, + name=op, + children=[left, right], + **kwargs + ) + + +@dataclass +class LiteralNode(Node): + """Literal value node""" + type: NodeKind = NodeKind.LITERAL + + def __init__(self, value: Any, **kwargs): + super().__init__( + type=NodeKind.LITERAL, + value=value, + **kwargs + ) + + +@dataclass +class ListNode(Node): + """List node""" + type: NodeKind = NodeKind.LIST + + def __init__(self, items: List['Node'] = None, **kwargs): + if items is None: + items = [] + super().__init__( + type=NodeKind.LIST, + children=items, + **kwargs + ) + + +@dataclass +class VarNode(Node): + """Pattern variable node""" + type: NodeKind = NodeKind.VAR + + def __init__(self, name: str, var_type: str = 'expr', **kwargs): + super().__init__( + type=NodeKind.VAR, + name=name, + alias=var_type, + **kwargs + ) + + +@dataclass +class VarListNode(Node): + """Pattern variable list node""" + type: NodeKind = NodeKind.VARLIST + + def __init__(self, name: str, items: List['Node'] = None, **kwargs): + if items is None: + items = [] + super().__init__( + type=NodeKind.VARLIST, + name=name, + children=items, + **kwargs + ) + +# ============================================================================ +# Query Structure Nodes +# ============================================================================ + +@dataclass +class SelectNode(Node): + """SELECT clause node""" + type: NodeKind = NodeKind.SELECT + + def __init__(self, items: List['Node'], **kwargs): + super().__init__( + type=NodeKind.SELECT, + children=items, + **kwargs + ) + + +@dataclass +class FromNode(Node): + """FROM clause node""" + type: NodeKind = NodeKind.FROM + + def __init__(self, tables: List['Node'], **kwargs): + super().__init__( + type=NodeKind.FROM, + children=tables, + **kwargs + ) + + +@dataclass +class WhereNode(Node): + """WHERE clause node""" + type: NodeKind = NodeKind.WHERE + + def __init__(self, condition: 'Node', **kwargs): + super().__init__( + type=NodeKind.WHERE, + children=[condition], + **kwargs + ) + + +@dataclass +class QueryNode(Node): + """Query root node""" + type: NodeKind = NodeKind.QUERY + + def __init__(self, select: Optional['Node'] = None, from_: Optional['Node'] = None, + where: Optional['Node'] = None, **kwargs): + children = [] + if select: + children.append(select) + if from_: + children.append(from_) + if where: + children.append(where) + super().__init__( + type=NodeKind.QUERY, + children=children, + **kwargs + ) + + + +if __name__ == '__main__': + + def basic_nodes(): + # Table nodes + employee_table = TableNode('employee', 'e1') + department_table = TableNode('department', 'd1') + + print(f"Table nodes:") + print(f" {employee_table.name} (alias: {employee_table.alias}) -> identity: {employee_table.get_identity()}") + print(f" {department_table.name} (alias: {department_table.alias}) -> identity: {department_table.get_identity()}") + + # Column nodes + id_column = ColumnNode('id', 'e1') + name_column = ColumnNode('name', 'e1') + salary_column = ColumnNode('salary', 'e1') + + print(f"\nColumn nodes:") + print(f" {id_column.name} (qualifier: {id_column.alias}) -> identity: {id_column.get_identity()}") + print(f" {name_column.name} (qualifier: {name_column.alias}) -> identity: {name_column.get_identity()}") + print(f" {salary_column.name} (qualifier: {salary_column.alias}) -> identity: {salary_column.get_identity()}") + + # Literal nodes + age_literal = LiteralNode(25) + salary_literal = LiteralNode(50000) + + print(f"\nLiteral nodes:") + print(f" Age: {age_literal.value}") + print(f" Salary: {salary_literal.value}") + + # Variable nodes for pattern matching + table_var = VarNode('V001', 'table') + column_var = VarNode('V002', 'column') + + print(f"\nVariable nodes:") + print(f" Table var: {table_var.name} (type: {table_var.alias})") + print(f" Column var: {column_var.name} (type: {column_var.alias})") + + + def simple_query(): + # SELECT clause + select_items = [ + ColumnNode('name', 'e1'), + ColumnNode('salary', 'e1') + ] + select_clause = SelectNode(select_items) + + # FROM clause + from_tables = [TableNode('employee', 'e1')] + from_clause = FromNode(from_tables) + + # WHERE clause + where_condition = ComparisonNode( + ColumnNode('age', 'e1'), + '>', + LiteralNode(25) + ) + where_clause = WhereNode(where_condition) + + # Complete query + query = QueryNode(select=select_clause, from_=from_clause, where=where_clause) + + print(f"\nQuery structure:") + print(f" Type: {query.type}") + print(f" Children: {len(query.children)} clauses") + print(f" SELECT items: {len(query.children[0].children)}") + print(f" FROM tables: {len(query.children[1].children)}") + print(f" WHERE conditions: {len(query.children[2].children)}") + + + def arithmetic_expressions(): + # Simple arithmetic + addition = ArithmeticNode(ColumnNode('salary', 'e1'), '+', LiteralNode(1000)) + multiplication = ArithmeticNode(addition, '*', LiteralNode(1.1)) + + # Complex arithmetic + complex_arithmetic = ArithmeticNode( + ArithmeticNode(ColumnNode('base_salary', 'e1'), '+', ColumnNode('bonus', 'e1')), + '*', + ArithmeticNode(LiteralNode(1), '+', ColumnNode('raise_percent', 'e1')) + ) + + print(f"\nArithmetic expressions:") + print(f" Simple: {addition.children[0].name} {addition.name} {addition.children[1].value}") + print(f" Nested: {multiplication.children[0].children[0].name} {multiplication.children[0].name} {multiplication.children[0].children[1].value} {multiplication.name} {multiplication.children[1].value}") + print(f" Complex: {len(complex_arithmetic.children)} terms") + + # Use in WHERE clause + where_with_arithmetic = WhereNode( + ComparisonNode( + complex_arithmetic, + '>', + LiteralNode(100000) + ) + ) + print(f" Used in WHERE: {where_with_arithmetic.children[0].name}") + + + def list_operations(): + # Create a list of columns + columns = [ + ColumnNode('id', 'e1'), + ColumnNode('name', 'e1'), + ColumnNode('salary', 'e1') + ] + column_list = ListNode(columns) + + print(f"\nColumn list:") + print(f" Items: {len(column_list.children)}") + for i, col in enumerate(column_list.children): + print(f" {i+1}. {col.name} (qualifier: {col.alias})") + + # Add more columns + column_list.children.append(ColumnNode('hire_date', 'e1')) + column_list.children.append(ColumnNode('department', 'e1')) + print(f" After adding: {len(column_list.children)} items") + + # Remove a column + column_list.children.pop(1) # Remove 'name' + print(f" After removing 'name': {len(column_list.children)} items") + + # Modify existing column + column_list.children[0].name = 'employee_id' + print(f" Modified first column: {column_list.children[0].name}") + + # Create variable list + var_list = VarListNode('VL001', [ + VarNode('V001', 'column'), + VarNode('V002', 'column') + ]) + print(f"\nVariable list:") + print(f" Name: {var_list.name}") + print(f" Items: {len(var_list.children)}") + for i, var in enumerate(var_list.children): + print(f" {i+1}. {var.name} (type: {var.alias})") + + + examples = [basic_nodes, simple_query, arithmetic_expressions, list_operations] + + for example in examples: + print("="*40) + example() From 139f69a433738c17bb43bfa24a19131734a924d0 Mon Sep 17 00:00:00 2001 From: Qiushi Bai Date: Sun, 5 Oct 2025 23:01:40 -0700 Subject: [PATCH 2/3] Organizing the code; Simplifying the NodeTypes by generalizing different operators into one OperatorNode model; Using Set semantics for most query structure's children; Making sure VarSQL variables use set semantics; Rewriting the tests accordingly; --- core/ast/__init__.py | 49 ++++++ core/ast/node.py | 170 ++++++++++++++++++ core/ast/node_type.py | 32 ++++ core/qb_structure.py | 399 ------------------------------------------ tests/test_ast.py | 337 +++++++++++++++++++++++++++++++++++ 5 files changed, 588 insertions(+), 399 deletions(-) create mode 100644 core/ast/__init__.py create mode 100644 core/ast/node.py create mode 100644 core/ast/node_type.py delete mode 100644 core/qb_structure.py create mode 100644 tests/test_ast.py diff --git a/core/ast/__init__.py b/core/ast/__init__.py new file mode 100644 index 0000000..9d54b22 --- /dev/null +++ b/core/ast/__init__.py @@ -0,0 +1,49 @@ +""" +AST (Abstract Syntax Tree) module for QueryBooster. + +This module provides the node types and classes for representing SQL query structures. +""" + +from .node_type import NodeType +from .node import ( + Node, + TableNode, + SubqueryNode, + ColumnNode, + LiteralNode, + VarNode, + VarSetNode, + OperatorNode, + FunctionNode, + SelectNode, + FromNode, + WhereNode, + GroupByNode, + HavingNode, + OrderByNode, + LimitNode, + OffsetNode, + QueryNode +) + +__all__ = [ + 'NodeType', + 'Node', + 'TableNode', + 'SubqueryNode', + 'ColumnNode', + 'LiteralNode', + 'VarNode', + 'VarSetNode', + 'OperatorNode', + 'FunctionNode', + 'SelectNode', + 'FromNode', + 'WhereNode', + 'GroupByNode', + 'HavingNode', + 'OrderByNode', + 'LimitNode', + 'OffsetNode', + 'QueryNode' +] \ No newline at end of file diff --git a/core/ast/node.py b/core/ast/node.py new file mode 100644 index 0000000..d27584b --- /dev/null +++ b/core/ast/node.py @@ -0,0 +1,170 @@ +from datetime import datetime +from typing import List, Set, Optional +from abc import ABC + +from .node_type import NodeType + +# ============================================================================ +# Base Node Structure +# ============================================================================ + +class Node(ABC): + """Base class for all nodes""" + def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None): + self.type = type + self.children = children if children is not None else set() + + +# ============================================================================ +# Operand Nodes +# ============================================================================ + +class TableNode(Node): + """Table reference node""" + def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.TABLE, **kwargs) + self.name = _name + self.alias = _alias + + +# TODO - including query structure arguments (similar to QueryNode) in constructor. +class SubqueryNode(Node): + """Subquery node""" + def __init__(self, query: 'Node', _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.SUBQUERY, children={query}, **kwargs) + self.alias = _alias + + +class ColumnNode(Node): + """Column reference node""" + def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): + super().__init__(NodeType.COLUMN, **kwargs) + self.name = _name + self.alias = _alias + self.parent_alias = _parent_alias + self.parent = _parent + + +class LiteralNode(Node): + """Literal value node""" + def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs): + super().__init__(NodeType.LITERAL, **kwargs) + self.value = _value + + +class VarNode(Node): + """VarSQL variable node""" + def __init__(self, _name: str, **kwargs): + super().__init__(NodeType.VAR, **kwargs) + self.name = _name + + +class VarSetNode(Node): + """VarSQL variable set node""" + def __init__(self, _name: str, **kwargs): + super().__init__(NodeType.VARSET, **kwargs) + self.name = _name + + +class OperatorNode(Node): + """Operator node""" + def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs): + children = [_left, _right] if _right else [_left] + super().__init__(NodeType.OPERATOR, children=children, **kwargs) + self.name = _name + + +class FunctionNode(Node): + """Function call node""" + def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs): + if _args is None: + _args = [] + super().__init__(NodeType.FUNCTION, children=_args, **kwargs) + self.name = _name + + +# ============================================================================ +# Query Structure Nodes +# ============================================================================ + +class SelectNode(Node): + """SELECT clause node""" + def __init__(self, _items: Set['Node'], **kwargs): + super().__init__(NodeType.SELECT, children=_items, **kwargs) + + +# TODO - confine the valid NodeTypes as children of FromNode +class FromNode(Node): + """FROM clause node""" + def __init__(self, _sources: Set['Node'], **kwargs): + super().__init__(NodeType.FROM, children=_sources, **kwargs) + + +class WhereNode(Node): + """WHERE clause node""" + def __init__(self, _predicates: Set['Node'], **kwargs): + super().__init__(NodeType.WHERE, children=_predicates, **kwargs) + + +class GroupByNode(Node): + """GROUP BY clause node""" + def __init__(self, _items: List['Node'], **kwargs): + super().__init__(NodeType.GROUP_BY, children=_items, **kwargs) + + +class HavingNode(Node): + """HAVING clause node""" + def __init__(self, _predicates: Set['Node'], **kwargs): + super().__init__(NodeType.HAVING, children=_predicates, **kwargs) + + +class OrderByNode(Node): + """ORDER BY clause node""" + def __init__(self, _items: List['Node'], **kwargs): + super().__init__(NodeType.ORDER_BY, children=_items, **kwargs) + + +class LimitNode(Node): + """LIMIT clause node""" + def __init__(self, _limit: int, **kwargs): + super().__init__(NodeType.LIMIT, **kwargs) + self.limit = _limit + + +class OffsetNode(Node): + """OFFSET clause node""" + def __init__(self, _offset: int, **kwargs): + super().__init__(NodeType.OFFSET, **kwargs) + self.offset = _offset + + +class QueryNode(Node): + """Query root node""" + def __init__(self, + _select: Optional['Node'] = None, + _from: Optional['Node'] = None, + _where: Optional['Node'] = None, + _group_by: Optional['Node'] = None, + _having: Optional['Node'] = None, + _order_by: Optional['Node'] = None, + _limit: Optional['Node'] = None, + _offset: Optional['Node'] = None, + **kwargs): + children = [] + if _select: + children.append(_select) + if _from: + children.append(_from) + if _where: + children.append(_where) + if _group_by: + children.append(_group_by) + if _having: + children.append(_having) + if _order_by: + children.append(_order_by) + if _limit: + children.append(_limit) + if _offset: + children.append(_offset) + super().__init__(NodeType.QUERY, children=children, **kwargs) diff --git a/core/ast/node_type.py b/core/ast/node_type.py new file mode 100644 index 0000000..2bae729 --- /dev/null +++ b/core/ast/node_type.py @@ -0,0 +1,32 @@ +from enum import Enum + +# ============================================================================ +# Node Type Enumeration +# ============================================================================ + +class NodeType(Enum): + """Node type enumeration""" + + # Operands + TABLE = "table" + SUBQUERY = "subquery" + COLUMN = "column" + LITERAL = "literal" + # VarSQL specific + VAR = "var" + VARSET = "varset" + + # Operators + OPERATOR = "operator" + FUNCTION = "function" + + # Query structure + SELECT = "select" + FROM = "from" + WHERE = "where" + GROUP_BY = "group_by" + HAVING = "having" + ORDER_BY = "order_by" + LIMIT = "limit" + OFFSET = "offset" + QUERY = "query" diff --git a/core/qb_structure.py b/core/qb_structure.py deleted file mode 100644 index 63c1403..0000000 --- a/core/qb_structure.py +++ /dev/null @@ -1,399 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, List, Optional -from enum import Enum -from abc import ABC - -# ============================================================================ -# Base Node Structure -# ============================================================================ - -class NodeKind(Enum): - """Node type enumeration""" - # Query structure - QUERY = "query" - SELECT = "select" - FROM = "from" - WHERE = "where" - GROUP_BY = "group_by" - HAVING = "having" - ORDER_BY = "order_by" - LIMIT = "limit" - OFFSET = "offset" - - # Tables and columns - TABLE = "table" - COLUMN = "column" - SUBQUERY = "subquery" - - # Expressions - LITERAL = "literal" - COMPARISON = "comparison" - LOGICAL = "logical" - ARITHMETIC = "arithmetic" - CASE = "case" - CAST = "cast" - - # Pattern matching - VAR = "var" - VARLIST = "varlist" - - # Other - LIST = "list" - TUPLE = "tuple" - - -@dataclass -class Node(ABC): - """Base class for all nodes""" - type: NodeKind - name: str = "" - alias: Optional[str] = None - value: Any = None - children: List['Node'] = field(default_factory=list) - - -# ============================================================================ -# Specific Node Types -# ============================================================================ - -@dataclass -class TableNode(Node): - """Table reference node""" - type: NodeKind = NodeKind.TABLE - - def __init__(self, name: str, alias: Optional[str] = None, **kwargs): - super().__init__( - type=NodeKind.TABLE, - name=name, - alias=alias, - **kwargs - ) - - def get_identity(self) -> str: - """Get table identity (alias-insensitive)""" - return self.name.lower() - - -@dataclass -class ColumnNode(Node): - """Column reference node""" - type: NodeKind = NodeKind.COLUMN - - def __init__(self, name: str, table_alias: Optional[str] = None, **kwargs): - super().__init__( - type=NodeKind.COLUMN, - name=name, - alias=table_alias, - **kwargs - ) - - def get_identity(self) -> str: - """Get column identity (qualifier-insensitive)""" - return self.name.lower() - - -@dataclass -class ComparisonNode(Node): - """Comparison operation node""" - type: NodeKind = NodeKind.COMPARISON - - def __init__(self, left: 'Node', op: str, right: 'Node', **kwargs): - super().__init__( - type=NodeKind.COMPARISON, - name=op, - children=[left, right], - **kwargs - ) - - -@dataclass -class LogicalNode(Node): - """Logical operation node (AND/OR)""" - type: NodeKind = NodeKind.LOGICAL - - def __init__(self, op: str, terms: List['Node'], **kwargs): - super().__init__( - type=NodeKind.LOGICAL, - name=op, - children=terms, - **kwargs - ) - - -@dataclass -class ArithmeticNode(Node): - """Arithmetic operation node""" - type: NodeKind = NodeKind.ARITHMETIC - - def __init__(self, left: 'Node', op: str, right: 'Node', **kwargs): - super().__init__( - type=NodeKind.ARITHMETIC, - name=op, - children=[left, right], - **kwargs - ) - - -@dataclass -class LiteralNode(Node): - """Literal value node""" - type: NodeKind = NodeKind.LITERAL - - def __init__(self, value: Any, **kwargs): - super().__init__( - type=NodeKind.LITERAL, - value=value, - **kwargs - ) - - -@dataclass -class ListNode(Node): - """List node""" - type: NodeKind = NodeKind.LIST - - def __init__(self, items: List['Node'] = None, **kwargs): - if items is None: - items = [] - super().__init__( - type=NodeKind.LIST, - children=items, - **kwargs - ) - - -@dataclass -class VarNode(Node): - """Pattern variable node""" - type: NodeKind = NodeKind.VAR - - def __init__(self, name: str, var_type: str = 'expr', **kwargs): - super().__init__( - type=NodeKind.VAR, - name=name, - alias=var_type, - **kwargs - ) - - -@dataclass -class VarListNode(Node): - """Pattern variable list node""" - type: NodeKind = NodeKind.VARLIST - - def __init__(self, name: str, items: List['Node'] = None, **kwargs): - if items is None: - items = [] - super().__init__( - type=NodeKind.VARLIST, - name=name, - children=items, - **kwargs - ) - -# ============================================================================ -# Query Structure Nodes -# ============================================================================ - -@dataclass -class SelectNode(Node): - """SELECT clause node""" - type: NodeKind = NodeKind.SELECT - - def __init__(self, items: List['Node'], **kwargs): - super().__init__( - type=NodeKind.SELECT, - children=items, - **kwargs - ) - - -@dataclass -class FromNode(Node): - """FROM clause node""" - type: NodeKind = NodeKind.FROM - - def __init__(self, tables: List['Node'], **kwargs): - super().__init__( - type=NodeKind.FROM, - children=tables, - **kwargs - ) - - -@dataclass -class WhereNode(Node): - """WHERE clause node""" - type: NodeKind = NodeKind.WHERE - - def __init__(self, condition: 'Node', **kwargs): - super().__init__( - type=NodeKind.WHERE, - children=[condition], - **kwargs - ) - - -@dataclass -class QueryNode(Node): - """Query root node""" - type: NodeKind = NodeKind.QUERY - - def __init__(self, select: Optional['Node'] = None, from_: Optional['Node'] = None, - where: Optional['Node'] = None, **kwargs): - children = [] - if select: - children.append(select) - if from_: - children.append(from_) - if where: - children.append(where) - super().__init__( - type=NodeKind.QUERY, - children=children, - **kwargs - ) - - - -if __name__ == '__main__': - - def basic_nodes(): - # Table nodes - employee_table = TableNode('employee', 'e1') - department_table = TableNode('department', 'd1') - - print(f"Table nodes:") - print(f" {employee_table.name} (alias: {employee_table.alias}) -> identity: {employee_table.get_identity()}") - print(f" {department_table.name} (alias: {department_table.alias}) -> identity: {department_table.get_identity()}") - - # Column nodes - id_column = ColumnNode('id', 'e1') - name_column = ColumnNode('name', 'e1') - salary_column = ColumnNode('salary', 'e1') - - print(f"\nColumn nodes:") - print(f" {id_column.name} (qualifier: {id_column.alias}) -> identity: {id_column.get_identity()}") - print(f" {name_column.name} (qualifier: {name_column.alias}) -> identity: {name_column.get_identity()}") - print(f" {salary_column.name} (qualifier: {salary_column.alias}) -> identity: {salary_column.get_identity()}") - - # Literal nodes - age_literal = LiteralNode(25) - salary_literal = LiteralNode(50000) - - print(f"\nLiteral nodes:") - print(f" Age: {age_literal.value}") - print(f" Salary: {salary_literal.value}") - - # Variable nodes for pattern matching - table_var = VarNode('V001', 'table') - column_var = VarNode('V002', 'column') - - print(f"\nVariable nodes:") - print(f" Table var: {table_var.name} (type: {table_var.alias})") - print(f" Column var: {column_var.name} (type: {column_var.alias})") - - - def simple_query(): - # SELECT clause - select_items = [ - ColumnNode('name', 'e1'), - ColumnNode('salary', 'e1') - ] - select_clause = SelectNode(select_items) - - # FROM clause - from_tables = [TableNode('employee', 'e1')] - from_clause = FromNode(from_tables) - - # WHERE clause - where_condition = ComparisonNode( - ColumnNode('age', 'e1'), - '>', - LiteralNode(25) - ) - where_clause = WhereNode(where_condition) - - # Complete query - query = QueryNode(select=select_clause, from_=from_clause, where=where_clause) - - print(f"\nQuery structure:") - print(f" Type: {query.type}") - print(f" Children: {len(query.children)} clauses") - print(f" SELECT items: {len(query.children[0].children)}") - print(f" FROM tables: {len(query.children[1].children)}") - print(f" WHERE conditions: {len(query.children[2].children)}") - - - def arithmetic_expressions(): - # Simple arithmetic - addition = ArithmeticNode(ColumnNode('salary', 'e1'), '+', LiteralNode(1000)) - multiplication = ArithmeticNode(addition, '*', LiteralNode(1.1)) - - # Complex arithmetic - complex_arithmetic = ArithmeticNode( - ArithmeticNode(ColumnNode('base_salary', 'e1'), '+', ColumnNode('bonus', 'e1')), - '*', - ArithmeticNode(LiteralNode(1), '+', ColumnNode('raise_percent', 'e1')) - ) - - print(f"\nArithmetic expressions:") - print(f" Simple: {addition.children[0].name} {addition.name} {addition.children[1].value}") - print(f" Nested: {multiplication.children[0].children[0].name} {multiplication.children[0].name} {multiplication.children[0].children[1].value} {multiplication.name} {multiplication.children[1].value}") - print(f" Complex: {len(complex_arithmetic.children)} terms") - - # Use in WHERE clause - where_with_arithmetic = WhereNode( - ComparisonNode( - complex_arithmetic, - '>', - LiteralNode(100000) - ) - ) - print(f" Used in WHERE: {where_with_arithmetic.children[0].name}") - - - def list_operations(): - # Create a list of columns - columns = [ - ColumnNode('id', 'e1'), - ColumnNode('name', 'e1'), - ColumnNode('salary', 'e1') - ] - column_list = ListNode(columns) - - print(f"\nColumn list:") - print(f" Items: {len(column_list.children)}") - for i, col in enumerate(column_list.children): - print(f" {i+1}. {col.name} (qualifier: {col.alias})") - - # Add more columns - column_list.children.append(ColumnNode('hire_date', 'e1')) - column_list.children.append(ColumnNode('department', 'e1')) - print(f" After adding: {len(column_list.children)} items") - - # Remove a column - column_list.children.pop(1) # Remove 'name' - print(f" After removing 'name': {len(column_list.children)} items") - - # Modify existing column - column_list.children[0].name = 'employee_id' - print(f" Modified first column: {column_list.children[0].name}") - - # Create variable list - var_list = VarListNode('VL001', [ - VarNode('V001', 'column'), - VarNode('V002', 'column') - ]) - print(f"\nVariable list:") - print(f" Name: {var_list.name}") - print(f" Items: {len(var_list.children)}") - for i, var in enumerate(var_list.children): - print(f" {i+1}. {var.name} (type: {var.alias})") - - - examples = [basic_nodes, simple_query, arithmetic_expressions, list_operations] - - for example in examples: - print("="*40) - example() diff --git a/tests/test_ast.py b/tests/test_ast.py new file mode 100644 index 0000000..89ae5eb --- /dev/null +++ b/tests/test_ast.py @@ -0,0 +1,337 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from core.ast.node_type import NodeType +from core.ast.node import ( + Node, TableNode, SubqueryNode, ColumnNode, LiteralNode, VarNode, VarSetNode, + OperatorNode, FunctionNode, SelectNode, FromNode, WhereNode, GroupByNode, + HavingNode, OrderByNode, LimitNode, OffsetNode, QueryNode +) + + +def test_operand_nodes(): + """Test all operand node types""" + print("="*50) + print("Testing Operand Nodes") + print("="*50) + + # Test TableNode + employees = TableNode("employees", "e") + departments = TableNode("departments") + + print(f"Table nodes:") + print(f" {employees.name} (alias: {employees.alias}) -> Type: {employees.type}") + print(f" {departments.name} (alias: {departments.alias}) -> Type: {departments.type}") + + # Test ColumnNode + emp_id = ColumnNode("id", _parent_alias="e") + emp_name = ColumnNode("name", "employee_name", "e") + dept_name = ColumnNode("name", _parent_alias="d") + + print(f"\nColumn nodes:") + print(f" {emp_id.name} (parent: {emp_id.parent_alias}) -> Type: {emp_id.type}") + print(f" {emp_name.name} (alias: {emp_name.alias}, parent: {emp_name.parent_alias}) -> Type: {emp_name.type}") + print(f" {dept_name.name} (parent: {dept_name.parent_alias}) -> Type: {dept_name.type}") + + # Test LiteralNode + num_literal = LiteralNode(42) + str_literal = LiteralNode("John Doe") + bool_literal = LiteralNode(True) + null_literal = LiteralNode(None) + + print(f"\nLiteral nodes:") + print(f" {num_literal.value} ({type(num_literal.value).__name__}) -> Type: {num_literal.type}") + print(f" '{str_literal.value}' ({type(str_literal.value).__name__}) -> Type: {str_literal.type}") + print(f" {bool_literal.value} ({type(bool_literal.value).__name__}) -> Type: {bool_literal.type}") + print(f" {null_literal.value} -> Type: {null_literal.type}") + + # Test VarSQL nodes + var_table = VarNode("V001") + var_column = VarNode("V002") + var_set = VarSetNode("VS001") + + print(f"\nVarSQL nodes:") + print(f" Variable {var_table.name} -> Type: {var_table.type}") + print(f" Variable {var_column.name} -> Type: {var_column.type}") + print(f" VarSet {var_set.name} -> Type: {var_set.type}") + + +def test_operator_nodes(): + """Test operator and function nodes""" + print("="*50) + print("Testing Operator and Function Nodes") + print("="*50) + + # Create some operands for testing + age_col = ColumnNode("age") + salary_col = ColumnNode("salary") + age_limit = LiteralNode(30) + salary_limit = LiteralNode(50000) + bonus_col = ColumnNode("bonus") + + # Test comparison operators + age_gt = OperatorNode(age_col, ">", age_limit) + salary_gte = OperatorNode(salary_col, ">=", salary_limit) + name_like = OperatorNode(ColumnNode("name"), "LIKE", LiteralNode("%John%")) + + print(f"Comparison operators:") + print(f" {age_gt.name} operator with {len(age_gt.children)} operands -> Type: {age_gt.type}") + print(f" {salary_gte.name} operator with {len(salary_gte.children)} operands -> Type: {salary_gte.type}") + print(f" {name_like.name} operator with {len(name_like.children)} operands -> Type: {name_like.type}") + + # Test logical operators + and_op = OperatorNode(age_gt, "AND", salary_gte) + or_op = OperatorNode(and_op, "OR", name_like) + not_op = OperatorNode(age_gt, "NOT") # Unary operator + + print(f"\nLogical operators:") + print(f" {and_op.name} operator with {len(and_op.children)} operands -> Type: {and_op.type}") + print(f" {or_op.name} operator with {len(or_op.children)} operands -> Type: {or_op.type}") + print(f" {not_op.name} operator with {len(not_op.children)} operands -> Type: {not_op.type}") + + # Test arithmetic operators + add_op = OperatorNode(salary_col, "+", bonus_col) + mult_op = OperatorNode(add_op, "*", LiteralNode(1.1)) + neg_op = OperatorNode(salary_col, "-") # Unary minus + + print(f"\nArithmetic operators:") + print(f" {add_op.name} operator with {len(add_op.children)} operands -> Type: {add_op.type}") + print(f" {mult_op.name} operator with {len(mult_op.children)} operands -> Type: {mult_op.type}") + print(f" {neg_op.name} operator with {len(neg_op.children)} operands -> Type: {neg_op.type}") + + # Test function nodes + count_func = FunctionNode("COUNT", {ColumnNode("*")}) + max_func = FunctionNode("MAX", {salary_col}) + concat_func = FunctionNode("CONCAT", {ColumnNode("first_name"), LiteralNode(" "), ColumnNode("last_name")}) + now_func = FunctionNode("NOW") # No arguments + + print(f"\nFunction nodes:") + print(f" {count_func.name}() with {len(count_func.children)} args -> Type: {count_func.type}") + print(f" {max_func.name}() with {len(max_func.children)} args -> Type: {max_func.type}") + print(f" {concat_func.name}() with {len(concat_func.children)} args -> Type: {concat_func.type}") + print(f" {now_func.name}() with {len(now_func.children)} args -> Type: {now_func.type}") + + +def test_query_structure_nodes(): + """Test query structure nodes""" + print("="*50) + print("Testing Query Structure Nodes") + print("="*50) + + # Create operands + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + + emp_id = ColumnNode("id", _parent_alias="e") + emp_name = ColumnNode("name", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + dept_id = ColumnNode("id", _parent_alias="d") + dept_name = ColumnNode("name", _parent_alias="d") + + # Test SELECT clause + select_clause = SelectNode({emp_id, emp_name, dept_name}) + print(f"SELECT clause with {len(select_clause.children)} items -> Type: {select_clause.type}") + + # Test FROM clause with JOIN + join_condition = OperatorNode(emp_dept_id, "=", dept_id) + from_clause = FromNode({emp_table, dept_table}) + print(f"FROM clause with {len(from_clause.children)} sources -> Type: {from_clause.type}") + + # Test WHERE clause + age_condition = OperatorNode(ColumnNode("age", _parent_alias="e"), ">", LiteralNode(25)) + salary_condition = OperatorNode(ColumnNode("salary", _parent_alias="e"), ">=", LiteralNode(40000)) + combined_condition = OperatorNode(age_condition, "AND", salary_condition) + where_clause = WhereNode({combined_condition}) + print(f"WHERE clause with {len(where_clause.children)} predicates -> Type: {where_clause.type}") + + # Test GROUP BY clause + group_by_clause = GroupByNode({dept_id, dept_name}) + print(f"GROUP BY clause with {len(group_by_clause.children)} items -> Type: {group_by_clause.type}") + + # Test HAVING clause + count_condition = OperatorNode(FunctionNode("COUNT", {emp_id}), ">", LiteralNode(5)) + having_clause = HavingNode({count_condition}) + print(f"HAVING clause with {len(having_clause.children)} predicates -> Type: {having_clause.type}") + + # Test ORDER BY clause + order_by_clause = OrderByNode({dept_name, emp_name}) + print(f"ORDER BY clause with {len(order_by_clause.children)} items -> Type: {order_by_clause.type}") + + # Test LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(20) + print(f"LIMIT clause: {limit_clause.limit} -> Type: {limit_clause.type}") + print(f"OFFSET clause: {offset_clause.offset} -> Type: {offset_clause.type}") + + +def test_complete_query(): + """Test building a complete query""" + print("="*50) + print("Testing Complete Query Construction") + print("="*50) + + # Build a complex query: + # SELECT e.name, d.name as dept_name, COUNT(*) as emp_count + # FROM employees e JOIN departments d ON e.department_id = d.id + # WHERE e.salary > 40000 AND e.age < 60 + # GROUP BY d.id, d.name + # HAVING COUNT(*) > 2 + # ORDER BY dept_name, emp_count DESC + # LIMIT 10 OFFSET 5 + + # Tables + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + + # Columns + emp_name = ColumnNode("name", _parent_alias="e") + dept_name = ColumnNode("name", "dept_name", "d") + emp_salary = ColumnNode("salary", _parent_alias="e") + emp_age = ColumnNode("age", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + dept_id = ColumnNode("id", _parent_alias="d") + count_star = FunctionNode("COUNT", {ColumnNode("*")}) + count_alias = ColumnNode("emp_count") # This would be the alias for COUNT(*) + + # SELECT clause + select_clause = SelectNode({emp_name, dept_name, count_star}) + + # FROM clause (with implicit JOIN logic) + from_clause = FromNode({emp_table, dept_table}) + + # WHERE clause + salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) + age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) + where_condition = OperatorNode(salary_condition, "AND", age_condition) + where_clause = WhereNode({where_condition}) + + # GROUP BY clause + group_by_clause = GroupByNode({dept_id, dept_name}) + + # HAVING clause + having_condition = OperatorNode(count_star, ">", LiteralNode(2)) + having_clause = HavingNode({having_condition}) + + # ORDER BY clause + order_by_clause = OrderByNode({dept_name, count_alias}) + + # LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(5) + + # Complete query + query = QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + print(f"Complete query built with {len(query.children)} clauses:") + print(f" Query type: {query.type}") + print(f" Total clauses: {len(query.children)}") + + # Analyze query structure + clause_types = [child.type for child in query.children] + print(f" Clause types: {[ct.value for ct in clause_types]}") + + return query + + +def test_varsql_pattern_matching(): + """Test VarSQL pattern matching capabilities""" + print("="*50) + print("Testing VarSQL Pattern Matching") + print("="*50) + + # Pattern: SELECT V1 FROM V2 WHERE V3 op V4 + var_select = VarNode("V1") # Any select item + var_table = VarNode("V2") # Any table + var_left = VarNode("V3") # Left operand of condition + var_op = VarNode("OP") # Any operator + var_right = VarNode("V4") # Right operand of condition + + # Build pattern query + pattern_select = SelectNode({var_select}) + pattern_from = FromNode({var_table}) + pattern_condition = OperatorNode(var_left, "=", var_right) # Could use var_op.name + pattern_where = WhereNode({pattern_condition}) + + pattern_query = QueryNode( + _select=pattern_select, + _from=pattern_from, + _where=pattern_where + ) + + print(f"Pattern query created:") + print(f" SELECT variables: {len(pattern_select.children)}") + print(f" FROM variables: {len(pattern_from.children)}") + print(f" WHERE conditions: {len(pattern_where.children)}") + print(f" Total pattern variables: 4 (V1, V2, V3, V4)") + + # Test VarSet for multiple columns + var_columns = VarSetNode("COLS") + multi_select = SelectNode({var_columns}) + print(f"\nVarSet pattern for multiple columns:") + print(f" VarSet {var_columns.name} can match multiple SELECT items") + + return pattern_query + + +def test_node_relationships(): + """Test node relationships and tree structure""" + print("="*50) + print("Testing Node Relationships") + print("="*50) + + # Build a simple expression tree: (a + b) * c + a = ColumnNode("a") + b = ColumnNode("b") + c = ColumnNode("c") + + add_op = OperatorNode(a, "+", b) + mult_op = OperatorNode(add_op, "*", c) + + print(f"Expression tree: (a + b) * c") + print(f" Root operator: {mult_op.name} ({mult_op.type})") + print(f" Root has {len(mult_op.children)} children") + + # The children are in a set, so we need to handle that + children = list(mult_op.children) + for i, child in enumerate(children): + print(f" Child {i+1}: {child.type}") + if hasattr(child, 'name'): + print(f" Name: {child.name}") + if hasattr(child, 'children') and child.children: + print(f" Has {len(child.children)} sub-children") + + +if __name__ == '__main__': + """Run all test functions""" + test_functions = [ + test_operand_nodes, + test_operator_nodes, + test_query_structure_nodes, + test_complete_query, + test_varsql_pattern_matching, + test_node_relationships + ] + + for test_func in test_functions: + try: + test_func() + print("\n") + except Exception as e: + print(f"ERROR in {test_func.__name__}: {e}") + import traceback + traceback.print_exc() + print("\n") + + print("="*50) + print("All tests completed!") + print("="*50) From 054916010d3638ae8539a17a759a4cfe3eb95816 Mon Sep 17 00:00:00 2001 From: Qiushi Bai Date: Mon, 6 Oct 2025 20:18:45 -0700 Subject: [PATCH 3/3] Addressing comment from Colin by removing hacky file path import in test_ast.py --- tests/test_ast.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_ast.py b/tests/test_ast.py index 89ae5eb..d6006a4 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,10 +1,5 @@ -import sys -import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from core.ast.node_type import NodeType from core.ast.node import ( - Node, TableNode, SubqueryNode, ColumnNode, LiteralNode, VarNode, VarSetNode, + TableNode, ColumnNode, LiteralNode, VarNode, VarSetNode, OperatorNode, FunctionNode, SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode, LimitNode, OffsetNode, QueryNode )