From 0c5e9fd9cd0f2f211da75f2f71f6d89831ce93e8 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Fri, 10 Oct 2025 18:13:25 -0400 Subject: [PATCH 1/4] organize queries, setup test for parser --- core/query_parser.py | 22 + data/queries.py | 1121 ++++++++++++++++++++++++++++++++++++ data/rules.py | 136 +++-- tests/test_query_parser.py | 473 +++++++++++++++ 4 files changed, 1707 insertions(+), 45 deletions(-) create mode 100644 core/query_parser.py create mode 100644 data/queries.py create mode 100644 tests/test_query_parser.py diff --git a/core/query_parser.py b/core/query_parser.py new file mode 100644 index 0000000..de8422f --- /dev/null +++ b/core/query_parser.py @@ -0,0 +1,22 @@ + +class QueryParser: + + def parse(self, query: str) -> QueryNode: + # Implement parsing logic using self.rules + pass + + # [1] Call mo_sql_parser + # str -> Any (JSON) + + # [2] Our new code + # Any (JSON) -> AST (QueryNode) + + def format(self, query: QueryNode) -> str: + # Implement formatting logic to convert AST back to SQL string + pass + + # [1] Our new code + # AST (QueryNode) -> JSON + + # [2] Call mo_sql_format + # Any (JSON) -> str \ No newline at end of file diff --git a/data/queries.py b/data/queries.py new file mode 100644 index 0000000..c84ce5d --- /dev/null +++ b/data/queries.py @@ -0,0 +1,1121 @@ + +queries = [ + { + 'id': 1, + 'name': 'Remove Cast Date Match Twice', + 'pattern': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE CAST(DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + AS DATE) IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''', + 'rewrite': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', created_at) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''' + }, + + { + 'id': 2, + 'name': 'Remove Cast Date Match Once', + 'pattern': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''', + 'rewrite': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', created_at) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''' + }, + + { + 'id': 3, + 'name': 'Remove Cast Date No Match', + 'pattern': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', created_at) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''', + 'rewrite': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', created_at) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(text, 'iphone') > 0) + GROUP BY 2; + ''' + }, + + + { + 'id': 4, + 'name': 'Replace Strpos Lower Match', + 'pattern': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE CAST(DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + AS DATE) IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(LOWER(text), 'iphone') > 0) + GROUP BY 2; + ''', + 'rewrite': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE CAST(DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + AS DATE) IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND text ILIKE '%iphone%' + GROUP BY 2; + ''' + }, + + { + 'id': 5, + 'name': 'Replace Strpos Lower No Match', + 'pattern': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND text ILIKE '%iphone%' + GROUP BY 2; + ''', + 'rewrite': ''' + SELECT SUM(1), + CAST(state_name AS TEXT) + FROM tweets + WHERE DATE_TRUNC('QUARTER', + CAST(created_at AS DATE)) + IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND text ILIKE '%iphone%' + GROUP BY 2; + ''' + }, + + + { + 'id': 6, + 'name': 'Remove Self Join Match', + 'pattern': ''' + SELECT e1.name, + e1.age, + e2.salary + FROM employee e1, employee e2 + WHERE e1.id = e2.id + AND e1.age > 17 + AND e2.salary > 35000; + ''', + 'rewrite': ''' + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE 1=1 + AND e1.age > 17 + AND e1.salary > 35000; + ''' + }, + + { + 'id': 7, + 'name': 'Remove Self Join No Match', + 'pattern': ''' + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + ''', + 'rewrite': ''' + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + ''' + }, + + { + 'id': 8, + 'name': 'Remove Self Join Match Simple', + 'pattern': ''' + SELECT e1.age + FROM employee e1, employee e2 + WHERE e1.id = e2.id + AND e1.age > 17; + ''', + 'rewrite': ''' + SELECT e1.age + FROM employee e1 + WHERE 1=1 + AND e1.age > 17; + ''' + }, + + { + 'id': 9, + 'name': 'Remove Self Join Advance Match', + 'pattern': ''' + SELECT e1.name, + e1.age, + e2.salary + FROM employee e1, employee e2 + WHERE e1.id = e2.id + AND e1.age > 17 + AND e2.salary > 35000; + ''', + 'rewrite': ''' + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE 1=1 + AND e1.age > 17 + AND e1.salary > 35000; + ''' + }, + + { + 'id': 10, + 'name': 'Subquery to Join Match 1', + 'pattern': ''' + select empno, firstnme, lastname, phoneno + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS') + and 1=1; + ''', + 'rewrite': ''' + select distinct empno, firstnme, lastname, phoneno + from employee, department + where employee.workdept = department.deptno + and deptname = 'OPERATIONS' + and 1=1; + ''' + }, + + { + 'id': 11, + 'name': 'Subquery to Join Match 2', + 'pattern': ''' + select empno, firstnme, lastname, phoneno + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS') + and age > 17; + ''', + 'rewrite': ''' + select distinct empno, firstnme, lastname, phoneno + from employee, department + where employee.workdept = department.deptno + and deptname = 'OPERATIONS' + and age > 17; + ''' + }, + + { + 'id': 12, + 'name': 'Subquery to Join Match 3', + 'pattern': ''' + select e.empno, e.firstnme, e.lastname, e.phoneno + from employee e + where e.workdept in + (select d.deptno + from department d + where d.deptname = 'OPERATIONS') + and e.age > 17; + ''', + 'rewrite': ''' + select distinct e.empno, e.firstnme, e.lastname, e.phoneno + from employee e, department d + where e.workdept = d.deptno + and d.deptname = 'OPERATIONS' + and e.age > 17; + ''' + }, + + { + 'id': 13, + 'name': 'Join to Filter Match 1', + 'pattern': ''' + SELECT * + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = + allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminrolei2_.admin_role_id = 1 + AND 1=1; + ''', + 'rewrite': ''' + SELECT * + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = + allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + AND 1=1; + ''' + }, + + { + 'id': 14, + 'name': 'Join to Filter Match 2', + 'pattern': ''' + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = + allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendy = 1 + AND adminrolei2_.admin_role_id = 1; + ''', + 'rewrite': ''' + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = + allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + AND adminpermi0_.is_friendy = 1; + ''' + }, + + { + 'id': 15, + 'name': 'Test Rule Wetune 90 Match', + 'pattern': ''' + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 + AND adminrolei2_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + ''', + 'rewrite': ''' + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE adminpermi0_.is_friendly = 1 + AND allroles1_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + ''' + }, + + { + 'id': 16, + 'name': 'Test Rule Calcite PushMinThroughUnion', + 'pattern': ''' + SELECT t.ENAME, + MIN(t.EMPNO) + FROM + (SELECT * + FROM EMP AS EMP + UNION ALL SELECT * + FROM EMP AS EMP) AS t + GROUP BY t.ENAME + ''', + 'rewrite': ''' + SELECT t6.ENAME, MIN(MIN(EMP.EMPNO)) + FROM (SELECT EMP.ENAME, MIN(EMP.EMPNO) + FROM EMP + GROUP BY EMP.ENAME + UNION ALL SELECT EMP.ENAME, MIN(EMP.EMPNO) + FROM EMP + GROUP BY EMP.ENAME) AS t6 + GROUP BY t6.ENAME + ''' + }, + + { + 'id': 17, + 'name': 'Remove Max Distinct', + 'pattern': ''' + SELECT A, MAX(DISTINCT (SELECT B FROM R WHERE C = 0)), D + FROM S; + ''', + 'rewrite': ''' + SELECT A, MAX((SELECT B FROM R WHERE C = 0)), D + FROM S; + ''' + }, + + { + 'id': 18, + 'name': 'Remove 1 Useless InnerJoin', + 'pattern': ''' + SELECT o_auth_applications.id + FROM o_auth_applications + INNER JOIN authorizations + ON o_auth_applications.id = authorizations.o_auth_application_id + WHERE authorizations.user_id = 1465 + ''', + 'rewrite': ''' + SELECT authorizations.o_auth_application_id + FROM authorizations + WHERE authorizations.user_id = 1465 + ''' + }, + + { + 'id': 19, + 'name': 'Stackoverflow 1', + 'pattern': ''' + SELECT DISTINCT my_table.foo, your_table.boo + FROM my_table, your_table + WHERE my_table.num = 1 OR your_table.num = 2 + ''', + 'rewrite': ''' + SELECT + my_table.foo, + your_table.boo + FROM + my_table, + your_table + WHERE + my_table.num = 1 + OR your_table.num = 2 + GROUP BY + my_table.foo, + your_table.boo + ''' + }, + + { + 'id': 20, + 'name': 'Partial Matching Base Case 1', + 'pattern': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 = 's1' OR b.cl1 ='s2' + ''', + 'rewrite': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 IN ('s1', 's2') + ''' + }, + + { + 'id': 21, + 'name': 'Partial Matching Base Case 2', + 'pattern': ''' + SELECT * + FROM b + WHERE + b.cl1 IN ('s1', 's2') OR b.cl1 ='s3' + ''', + 'rewrite': ''' + SELECT * + FROM b + WHERE + b.cl1 IN ('s3', 's1', 's2') + ''' + }, + + { + 'id': 22, + 'name': 'Partial Matching 0', + 'pattern': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 = 's1' OR b.cl1 = 's2' OR b.cl1 = 's3' + ''', + 'rewrite': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 IN ('s1', 's2') OR b.cl1 = 's3' + ''' + }, + + { + 'id': 23, + 'name': 'Partial Matching 1', + 'pattern': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 = 's1' OR b.cl1 = 's2' OR b.cl1 = 's3' + ''', + 'rewrite': ''' + SELECT * + FROM A a + LEFT JOIN B b ON a.id = b.cid + WHERE + b.cl1 IN ('s3', 's1', 's2') + ''' + }, + + { + 'id': 24, + 'name': 'Partial Matching 4', + 'pattern': ''' + select empno, firstname, lastname, phoneno + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS') + and firstname like 'B%' + ''', + 'rewrite': ''' + select distinct empno, firstname, lastname, phoneno + from employee, department + where employee.workdept = department.deptno + and deptname = 'OPERATIONS' + and firstname like 'B%' + ''' + }, + + { + 'id': 25, + 'name': 'Partial Keeps Remaining OR', + 'pattern': ''' + SELECT entities.data + FROM entities + WHERE entities._id IN (SELECT index_users_email._id + FROM index_users_email + WHERE index_users_email.key = 'test') + OR entities._id IN (SELECT index_users_profile_name._id + FROM index_users_profile_name + WHERE index_users_profile_name.key = 'test') + ''', + 'rewrite': ''' + SELECT entities.data + FROM entities + INNER JOIN index_users_email ON index_users_email._id = entities._id + WHERE index_users_email.key = 'test' + OR entities._id IN (SELECT index_users_profile_name._id + FROM index_users_profile_name + WHERE index_users_profile_name.key = 'test') + ''' + }, + + { + 'id': 26, + 'name': 'Partial Keeps Remaining AND', + 'pattern': ''' + SELECT Empno + FROM EMP + WHERE EMPNO > 10 + AND EMPNO <= 10 + AND EMPNAME LIKE '%Jason%' + ''', + 'rewrite': ''' + SELECT Empno + FROM EMP + WHERE FALSE + AND EMPNAME LIKE '%Jason%' + ''' + }, + + { + 'id': 27, + 'name': 'And On True', + 'pattern': ''' + SELECT people.name + FROM people + WHERE 1 AND 1 + ''', + 'rewrite': ''' + SELECT people.name + FROM people + ''' + }, + + { + 'id': 28, + 'name': 'Multiple And On True', + 'pattern': ''' + SELECT name + FROM people + WHERE 1 = 1 AND 2 = 2 + ''', + 'rewrite': ''' + SELECT name + FROM people + ''' + }, + + { + 'id': 29, + 'name': 'Remove Where True', + 'pattern': ''' + SELECT * + FROM Emp + WHERE age > age - 2; + ''', + 'rewrite': ''' + SELECT * + FROM Emp + ''' + }, + + # Rewrite Skips Failed Partial + { + 'id': 30, + 'name': 'Rewrite Skips Failed Partial', + 'pattern': ''' + SELECT * + FROM accounts + WHERE LOWER(accounts.firstname) = LOWER('Sam') + AND accounts.id IN (SELECT addresses.account_id + FROM addresses + WHERE LOWER(addresses.name) = LOWER('Street1')) + AND accounts.id IN (SELECT alternate_ids.account_id + FROM alternate_ids + WHERE alternate_ids.alternate_id_glbl = '5'); + ''', + 'rewrite': ''' + SELECT * + FROM accounts + JOIN addresses ON accounts.id = addresses.account_id + JOIN alternate_ids ON accounts.id = alternate_ids.account_id + WHERE LOWER(accounts.firstname) = LOWER('Sam') + AND LOWER(addresses.name) = LOWER('Street1') + AND alternate_ids.alternate_id_glbl = '5'; + ''' + }, + + { + 'id': 31, + 'name': 'Matching Order', + 'pattern': ''' + SELECT entities.data FROM entities WHERE + entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') + OR + entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') + ''', + 'rewrite': ''' + SELECT entities.data FROM entities INNER JOIN index_users_email ON index_users_email._id = entities._id + WHERE index_users_email.key = 'test' + UNION + SELECT entities.data FROM entities INNER JOIN index_users_profile_name ON index_users_profile_name._id = entities._id + WHERE index_users_profile_name.key = 'test' + ''' + }, + + { + 'id': 32, + 'name': 'No Over Matching', + 'pattern': ''' + SELECT entities.data FROM entities WHERE + entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') + OR + entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') + ''', + 'rewrite': ''' + SELECT + entities.data + FROM + entities + INNER JOIN index_users_email ON index_users_email._id = entities._id + WHERE + index_users_email.key = 'test' + OR entities._id IN ( + SELECT + index_users_profile_name._id + FROM + index_users_profile_name + WHERE + index_users_profile_name.key = 'test' + ) + ''' + }, + + { + 'id': 33, + 'name': 'Full Matching', + 'pattern': ''' + SELECT entities.data FROM entities WHERE entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') + UNION + SELECT entities.data FROM entities WHERE entities._id IN (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') + ''', + 'rewrite': ''' + SELECT entities.data FROM entities INNER JOIN index_users_email ON index_users_email._id = entities._id WHERE index_users_email.key = 'test' + UNION + SELECT entities.data FROM entities INNER JOIN index_users_profile_name ON index_users_profile_name._id = entities._id WHERE index_users_profile_name.key = 'test' + ''' + }, + + { + 'id': 34, + 'name': 'Over Partial Matching', + 'pattern': ''' + SELECT * FROM table_name WHERE (table_name.title = 1 and table_name.grade = 2) OR (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) OR (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3) + ''', + 'rewrite': ''' + SELECT * FROM table_name WHERE (table_name.title = 1 and table_name.grade = 2) OR (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) OR (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3) + ''' + }, + + { + 'id': 35, + 'name': 'Aggregation to Subquery', + 'pattern': ''' +SELECT + t1.CPF, + DATE(t1.data) AS data, + CASE WHEN SUM(CASE WHEN t1.login_ok = true + THEN 1 + ELSE 0 + END) >= 1 + THEN true + ELSE false + END +FROM db_risco.site_rn_login AS t1 +GROUP BY t1.CPF, DATE(t1.data) + ''', + 'rewrite': ''' +SELECT + t1.CPF, + t1.data +FROM ( + SELECT + CPF, + DATE(data) + FROM db_risco.site_rn_login + WHERE login_ok = true +) t1 +GROUP BY t1.CPF, t1.data + ''' + }, + + { + 'id': 36, + 'name': 'Spreadsheet ID 2', + 'pattern': ''' +SELECT * +FROM place +WHERE "select" = TRUE + OR exists (SELECT id + FROM bookmark + WHERE user IN (1,2,3,4) + AND bookmark.place = place.id) + LIMIT 10; + ''', + 'rewrite': ''' +SELECT * +FROM ( + (SELECT * + FROM place + WHERE "select" = True + LIMIT 10) +UNION + (SELECT * + FROM place + WHERE EXISTS + (SELECT 1 + FROM bookmark + WHERE user IN (1, 2, 3, 4) + AND bookmark.place = place.id) + LIMIT 10)) +LIMIT 10 + ''' + }, + + { + 'id': 37, + 'name': 'Spreadsheet ID 3', + 'pattern': ''' +SELECT EMPNO FROM EMP WHERE EMPNO > 10 AND EMPNO <= 10 + ''', + 'rewrite': ''' +SELECT EMPNO FROM EMP WHERE FALSE + ''' + }, + + { + 'id': 38, + 'name': 'Spreadsheet ID 4', + 'pattern': '''SELECT entities.data FROM entities WHERE + entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') + OR + entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') + ''', + 'rewrite': '''SELECT entities.data FROM entities +WHERE entities._id IN + ( SELECT index_users_email._id + FROM index_users_email + WHERE index_users_email.key = 'test' + ) +UNION +SELECT entities.data FROM entities +WHERE entities._id in + ( SELECT index_users_profile_name._id + FROM index_users_profile_name + WHERE index_users_profile_name.key = 'test' + )''' + }, + + { + 'id': 39, + 'name': 'Spreadsheet ID 6', + 'pattern': ''' +SELECT * +FROM + table_name + WHERE + (table_name.title = 1 and table_name.grade = 2) + OR + (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) + OR + (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3) + ''', + 'rewrite': ''' +SELECT * +FROM + table_name + WHERE + 1 = case + when table_name.title = 1 and table_name.grade = 2 then 1 + when table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3 then 1 + when table_name.prog = 1 and table_name.title = 1 and table_name.debt = 3 then 1 + else 0 + end + ''' + }, + + { + 'id': 40, + 'name': 'Spreadsheet ID 7', + 'pattern': ''' +select * from +a +left join b on a.id = b.cid +where +b.cl1 = 's1' +or +b.cl1 ='s2' +or +b.cl1 ='s3' + ''', + 'rewrite': ''' +select * from +a +left join b on a.id = b.cid +where +b.cl1 in ('s1','s2','s3') + ''' + }, + + { + 'id': 41, + 'name': 'Spreadsheet ID 9', + 'pattern': ''' +SELECT DISTINCT my_table.foo +FROM my_table +WHERE my_table.num = 1; + ''', + 'rewrite': ''' +SELECT my_table.foo +FROM my_table +WHERE my_table.num = 1 +GROUP BY my_table.foo; + ''' + }, + + { + 'id': 42, + 'name': 'Spreadsheet ID 10', + 'pattern': ''' +SELECT table1.wpis_id +FROM table1 +WHERE table1.etykieta_id IN ( + SELECT table2.tag_id + FROM table2 + WHERE table2.postac_id = 376476 + ); + ''', + 'rewrite': ''' +SELECT table1.wpis_id +FROM table1 +INNER JOIN table2 on table2.tag_id = table1.etykieta_id +WHERE table2.postac_id = 376476 + ''' + }, + + { + 'id': 43, + 'name': 'Spreadsheet ID 11', + 'pattern': ''' +SELECT historicoestatusrequisicion_id, requisicion_id, estatusrequisicion_id, + comentario, fecha_estatus, usuario_id + FROM historicoestatusrequisicion hist1 + WHERE requisicion_id IN + ( + SELECT requisicion_id FROM historicoestatusrequisicion hist2 + WHERE usuario_id = 27 AND estatusrequisicion_id = 1 + ) + ORDER BY requisicion_id, estatusrequisicion_id + ''', + 'rewrite': ''' +SELECT hist1.historicoestatusrequisicion_id, hist1.requisicion_id, hist1.estatusrequisicion_id, hist1.comentario, hist1.fecha_estatus, hist1.usuario_id + FROM historicoestatusrequisicion hist1 + JOIN historicoestatusrequisicion hist2 ON hist2.requisicion_id = hist1.requisicion_id + WHERE hist2.usuario_id = 27 AND hist2.estatusrequisicion_id = 1 + ORDER BY hist1.requisicion_id, hist1.estatusrequisicion_id + ''' + }, + + { + 'id': 44, + 'name': 'Spreadsheet ID 12', + 'pattern': ''' +SELECT po.id, + SUM(grouped_items.total_quantity) AS order_total_quantity +FROM purchase_orders po +LEFT JOIN ( + SELECT items.purchase_order_id, + SUM(items.quantity) AS item_total + FROM items + GROUP BY items.purchase_order_id +) grouped_items ON po.id = grouped_items.purchase_order_id +WHERE po.shop_id = 195 +GROUP BY po.id + ''', + 'rewrite': ''' +SELECT po.id, + ( + SELECT SUM(items.quantity) + FROM items + WHERE items.purchase_order_id = po.id + GROUP BY items.purchase_order_id + ) AS order_total_quantity +FROM purchase_orders po +WHERE shop_id = 195 +GROUP BY po.id + ''' + }, + + { + 'id': 45, + 'name': 'Spreadsheet ID 15', + 'pattern': ''' +SELECT * +FROM users u +WHERE u.id IN + (SELECT s1.user_id + FROM sessions s1 + WHERE s1.user_id <> 1234 + AND (s1.ip IN + (SELECT s2.ip + FROM sessions s2 + WHERE s2.user_id = 1234 + GROUP BY s2.ip) + OR s1.cookie_identifier IN + (SELECT s3.cookie_identifier + FROM sessions s3 + WHERE s3.user_id = 1234 + GROUP BY s3.cookie_identifier)) + GROUP BY s1.user_id) + ''', + 'rewrite': ''' +SELECT * +FROM users u +WHERE EXISTS ( + SELECT + NULL + FROM sessions s1 + WHERE s1.user_id <> 1234 + AND u.id = s1.user_id + AND EXISTS ( + SELECT + NULL + FROM sessions s2 + WHERE s2.user_id = 1234 + AND (s1.ip = s2.ip + OR s1.cookie_identifier = s2.cookie_identifier + ) + ) + ) + ''' + }, + + { + 'id': 46, + 'name': 'Spreadsheet ID 18', + 'pattern': ''' +SELECT DISTINCT ON (t.playerId) t.gzpId, t.pubCode, t.playerId, + COALESCE (p.preferenceValue,'en'), + s.segmentId +FROM userPlayerIdMap t LEFT JOIN + userPreferences p + ON t.gzpId = p.gzpId LEFT JOIN + segment s + ON t.gzpId = s.gzpId +WHERE t.pubCode IN ('hyrmas','ayqioa','rj49as99') and + t.provider IN ('FCM','ONE_SIGNAL') and + s.segmentId IN (0,1,2,3,4,5,6) and + p.preferenceValue IN ('en','hi') +ORDER BY t.playerId desc; + ''', + 'rewrite': ''' +SELECT t.gzpId, t.pubCode, t.playerId, + COALESCE((SELECT p.preferenceValue + FROM userPreferences p + WHERE t.gzpId = p.gzpId AND + p.preferenceValue IN ('en', 'hi') + LIMIT 1 + ), 'en' + ), + (SELECT s.segmentId + FROM segment s + WHERE t.gzpId = s.gzpId AND + s.segmentId IN (0, 1, 2, 3, 4, 5, 6) + LIMIT 1 + ) +FROM userPlayerIdMap t +WHERE t.pubCode IN ('hyrmas', 'ayqioa', 'rj49as99') and + t.provider IN ('FCM', 'ONE_SIGNAL'); + ''' + }, + + { + 'id': 47, + 'name': 'Spreadsheet ID 20', + 'pattern': ''' +SELECT * FROM (SELECT * FROM (SELECT NULL FROM EMP) WHERE N IS NULL) WHERE N IS NULL + ''', + 'rewrite': ''' +SELECT NULL FROM EMP + ''' + }, + + { + 'id': 48, + 'name': 'PostgreSQL Test', + 'pattern': ''' + SELECT "tweets"."latitude" AS "latitude", + "tweets"."longitude" AS "longitude" + FROM "public"."tweets" "tweets" + WHERE (("tweets"."latitude" >= -90) AND ("tweets"."latitude" <= 80) + AND ((("tweets"."longitude" >= -173.80000000000001) AND ("tweets"."longitude" <= 180)) OR ("tweets"."longitude" IS NULL)) + AND (CAST((DATE_TRUNC( \'day\', CAST("tweets"."created_at" AS DATE) ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL \'1 DAY\')) AS DATE) + = (TIMESTAMP \'2018-04-22 00:00:00.000\')) + AND (STRPOS(CAST(LOWER(CAST(CAST("tweets"."text" AS TEXT) AS TEXT)) AS TEXT),CAST(\'microsoft\' AS TEXT)) > 0)) + GROUP BY 1, 2 + ''', + 'rewrite': ''' + SELECT "tweets"."latitude" AS "latitude", + "tweets"."longitude" AS "longitude" + FROM "public"."tweets" "tweets" + WHERE (("tweets"."latitude" >= -90) AND ("tweets"."latitude" <= 80) + AND ((("tweets"."longitude" >= -173.80000000000001) AND ("tweets"."longitude" <= 180)) OR ("tweets"."longitude" IS NULL)) + AND ((DATE_TRUNC( \'day\', "tweets"."created_at" ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL \'1 DAY\')) + = (TIMESTAMP \'2018-04-22 00:00:00.000\')) + AND "tweets"."text" ILIKE \'%microsoft%\') + GROUP BY 1, 2 + ''' + }, + + { + 'id': 49, + 'name': 'MySQL Test', + 'pattern': '''SELECT `tweets`.`latitude` AS `latitude`, + `tweets`.`longitude` AS `longitude` + FROM `tweets` + WHERE ((ADDDATE(DATE_FORMAT(`tweets`.`created_at`, '%Y-%m-01 00:00:00'), INTERVAL 0 SECOND) = TIMESTAMP('2017-03-01 00:00:00')) + AND (LOCATE('iphone', LOWER(`tweets`.`text`)) > 0)) + GROUP BY 1, 2''', + 'rewrite': '''SELECT `tweets`.`latitude` AS `latitude`, + `tweets`.`longitude` AS `longitude` + FROM `tweets` + WHERE ((DATE_FORMAT(`tweets`.`created_at`, '%Y-%m-01 00:00:00') = TIMESTAMP('2017-03-01 00:00:00')) + AND (LOCATE('iphone', LOWER(`tweets`.`text`)) > 0)) + GROUP BY 1, 2''' + } +] + + +def get_query(query_id: int) -> dict: + return next(filter(lambda x: x['id'] == query_id, queries), None) \ No newline at end of file diff --git a/data/rules.py b/data/rules.py index 0161b6d..154a11d 100644 --- a/data/rules.py +++ b/data/rules.py @@ -18,7 +18,8 @@ 'actions': '', # 'actions_json': '[]', # 'mapping': '{"x": "V1"}', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [17] }, { @@ -34,7 +35,8 @@ 'actions': '', # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\"}", - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [1, 2, 48] }, { @@ -45,7 +47,8 @@ 'constraints': 'TYPE(x)=TEXT', 'rewrite': '', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [48] }, { @@ -61,7 +64,8 @@ 'actions': '', # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\", \"y\": \"V2\"}", - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [4, 48] }, { @@ -88,7 +92,8 @@ 'actions': 'SUBSTITUTE(s1, t2, t1) and\n SUBSTITUTE(p1, t2, t1)', # 'actions_json': "[{\"function\": \"substitute\", \"variables\": [\"VL1\", \"V3\", \"V2\"]}, {\"function\": \"substitute\", \"variables\": [\"VL2\", \"V3\", \"V2\"]}]", # 'mapping': "{\"s1\": \"VL1\", \"p1\": \"VL2\", \"tb1\": \"V1\", \"t1\": \"V2\", \"t2\": \"V3\", \"a1\": \"V4\"}", - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [6, 8, 9] }, { @@ -110,7 +115,8 @@ and <> ''', 'actions': 'SUBSTITUTE(s1, t2, t1) and\n SUBSTITUTE(p1, t2, t1)', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [6, 8, 9] }, { @@ -132,7 +138,8 @@ and <> ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [10, 11, 12, 24] }, { @@ -156,7 +163,8 @@ and <> ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [13, 14] }, { @@ -180,7 +188,8 @@ and <> ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [] }, { @@ -200,7 +209,8 @@ WHERE . = ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [] }, { @@ -222,7 +232,8 @@ AND . = ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [] }, { @@ -244,7 +255,8 @@ AND <> ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [] }, { @@ -264,7 +276,8 @@ WHERE ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [18] }, { @@ -275,7 +288,8 @@ 'constraints': '', 'rewrite': 'FROM ', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [29] }, { @@ -286,7 +300,8 @@ 'constraints': '', 'rewrite': 'SELECT . FROM INNER JOIN ON . = . WHERE . = ', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [25, 30, 31, 32, 33] }, { @@ -297,7 +312,8 @@ 'constraints': '', 'rewrite': 'SELECT <> FROM <> WHERE False', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [26] }, { @@ -318,7 +334,8 @@ AND <> AND <>''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [30] }, { @@ -329,7 +346,8 @@ 'constraints': '', 'rewrite': '''SELECT ., . FROM (SELECT , DATE() FROM WHERE = ) AS GROUP BY <>, .''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [35] }, { @@ -340,7 +358,8 @@ 'constraints': '', 'rewrite': '''SELECT <> FROM ((SELECT <> FROM WHERE LIMIT ) UNION (SELECT <> FROM WHERE EXISTS (SELECT FROM WHERE IN (, , , ) AND <>) LIMIT )) LIMIT ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [36] }, { @@ -351,7 +370,8 @@ 'constraints': '', 'rewrite': '''FALSE''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [37] }, { @@ -362,7 +382,8 @@ 'constraints': '', 'rewrite': '''SELECT <> FROM WHERE . IN (SELECT <> FROM WHERE <>) UNION SELECT <> FROM WHERE . IN (SELECT <> FROM WHERE <>)''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [38] }, { @@ -373,7 +394,8 @@ 'constraints': '', 'rewrite': '''1 = CASE WHEN THEN 1 WHEN THEN 1 WHEN THEN 1 ELSE 0 END''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [39] }, { @@ -384,7 +406,8 @@ 'constraints': '', 'rewrite': '''. IN ('', '', '')''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [40] }, { @@ -395,7 +418,8 @@ 'constraints': '', 'rewrite': '''SELECT FROM WHERE <> GROUP BY ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [41] }, { @@ -406,7 +430,8 @@ 'constraints': '', 'rewrite': '''FROM INNER JOIN ON . = . WHERE <>''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [42] }, { @@ -417,7 +442,8 @@ 'constraints': '', 'rewrite': '''SELECT ., ., ., ., ., . FROM JOIN ON . = . WHERE . = AND . = ORDER BY ., .''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [43] }, { @@ -428,7 +454,8 @@ 'constraints': '', 'rewrite': '''SELECT <>, (SELECT FROM WHERE . = . GROUP BY ) AS FROM WHERE = ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [44] }, { @@ -439,7 +466,8 @@ 'constraints': '', 'rewrite': '''EXISTS (SELECT NULL FROM WHERE <> AND . = . AND EXISTS (SELECT NULL FROM WHERE <> AND (. = . OR . = .)))''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [45] }, { @@ -450,7 +478,8 @@ 'constraints': '', 'rewrite': '''SELECT , , , COALESCE((SELECT . FROM WHERE <> AND <> LIMIT 1), ), (SELECT <> FROM WHERE <> AND . IN (, , , , , , ) LIMIT ) FROM WHERE AND ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [46] }, { @@ -461,7 +490,8 @@ 'constraints': '', 'rewrite': '''SELECT NULL FROM ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [47] }, { @@ -489,7 +519,8 @@ LIMIT ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [15] }, { @@ -525,7 +556,8 @@ LIMIT 50 ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [15] }, { @@ -553,7 +585,8 @@ GROUP BY t6. ''', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [16] }, # MySQL Rules @@ -571,7 +604,8 @@ 'actions': '', # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\"}", - 'database': 'mysql' + 'database': 'mysql', + 'examples': [49] }, { @@ -587,7 +621,8 @@ 'actions': '', # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\", \"y\": \"V2\"}", - 'database': 'mysql' + 'database': 'mysql', + 'examples': [49] }, { @@ -598,7 +633,8 @@ 'constraints': '', 'rewrite': 'SELECT <> FROM <> WHERE <> GROUP BY <>', 'actions': '', - 'database': 'postgresql' + 'database': 'postgresql', + 'examples': [19] }, { 'id': 2258, @@ -608,7 +644,8 @@ 'constraints': '', 'rewrite': ' IN (, )', 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [20, 22, 23, 40] }, { 'id': 2280, @@ -618,7 +655,8 @@ 'constraints': '', 'rewrite': ' IN (, , )', 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [34] }, { 'id': 2259, @@ -628,7 +666,8 @@ 'constraints': '', 'rewrite': ' IN (<>, )', 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [21] }, { 'id': 2260, @@ -638,7 +677,8 @@ 'constraints': '', 'rewrite': ' IN (<>, <>)', 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [] }, { "id": 2261, @@ -648,7 +688,8 @@ 'constraints': '', "rewrite": " IN (<>, <>)", 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [] }, { "id": 2262, @@ -658,7 +699,8 @@ 'constraints': '', "rewrite": "SELECT DISTINCT , , , FROM , WHERE . = . AND <>", 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [24] }, { "id": 2263, @@ -668,7 +710,8 @@ 'constraints': '', "rewrite": "FROM ", 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [27] }, { "id": 2264, @@ -678,7 +721,8 @@ 'constraints': '', "rewrite": "FROM ", 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [28] }, { "id": 2265, @@ -688,7 +732,8 @@ 'constraints': '', "rewrite": "SELECT <> FROM WHERE . IN (SELECT <> FROM WHERE <>) UNION SELECT <> FROM WHERE . IN (SELECT <> FROM WHERE <>)", 'actions': '', - 'database': 'mysql' + 'database': 'mysql', + 'examples': [38] } ] @@ -713,6 +758,7 @@ def get_rule(key: str) -> dict: 'actions_json': json.loads(rule['actions_json']), 'mapping': json.loads(rule['mapping']), 'database': rule['database'] + 'examples': rule['examples'] } # return a list of rules (json attributes are in str) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py new file mode 100644 index 0000000..c768e49 --- /dev/null +++ b/tests/test_query_parser.py @@ -0,0 +1,473 @@ +import mo_sql_parsing as mosql +from core.qb_parser import QBParser, parse_sql_to_qb_ast +from core.ast.node import ( + QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, LimitNode, OffsetNode +) +from core.ast.node_type import NodeType +from data.queries import get_query + + +def test_parse_1(): + """ + SELECT clause: + - SUM() aggregate function + - CAST(column AS TEXT) function + - Column references (state_name) + + FROM clause: + - Simple table reference (tweets) + + WHERE clause: + - CAST() function with nested expressions + - DATE_TRUNC() function with string literal and column + - Nested CAST() functions + - IN operator with list of TIMESTAMP literals + - AND logical operator + - STRPOS() function with column and string literal + - Comparison operator (>) + - Numeric literal (0) + + GROUP BY clause: + - Numeric literal reference (2) + """ + + query = get_query(1) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 2 + + # Check FROM clause + from_clause = None + for child in qb_ast.children: + if child.type == NodeType.FROM: + from_clause = child + break + + # assert from_clause is not None + # table_node = list(from_clause.children)[0] + # assert isinstance(table_node, TableNode) + # assert table_node.name == "tweets" + + # Check WHERE clause + where_clause = None + for child in qb_ast.children: + if child.type == NodeType.WHERE: + where_clause = child + break + + # assert where_clause is not None + # assert len(where_clause.children) == 1 + + # Check GROUP BY clause + group_by_clause = None + for child in qb_ast.children: + if child.type == NodeType.GROUP_BY: + group_by_clause = child + break + + # assert group_by_clause is not None + # assert len(group_by_clause.children) == 1 + + +def test_parse_2(): + """ + SELECT clause: + - SUM() aggregate function + - CAST(column AS TEXT) function + - Column references (state_name) + + FROM clause: + - Simple table reference (tweets) + + WHERE clause: + - DATE_TRUNC() function with string literal and column + - CAST() function with DATE type + - IN operator with list of TIMESTAMP literals + - AND logical operator + - STRPOS() function with column and string literal + - Comparison operator (>) + - Numeric literal (0) + + GROUP BY clause: + - Numeric literal reference (2) + """ + + query = get_query(2) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Find the STRPOS condition in WHERE clause + where_clause = None + for child in qb_ast.children: + if child.type == NodeType.WHERE: + where_clause = child + break + + # assert where_clause is not None + # condition = list(where_clause.children)[0] + # assert isinstance(condition, OperatorNode) + # assert condition.name == "AND" + + # The condition should have two operands + # operands = list(condition.children) + # assert len(operands) == 2 + + +def test_parse_3(): + """ + SELECT clause: + - SUM() aggregate function + - CAST(column AS TEXT) function + - Column references (state_name) + + FROM clause: + - Simple table reference (tweets) + + WHERE clause: + - DATE_TRUNC() function with column reference + - IN operator with list of TIMESTAMP literals + - AND logical operator + - STRPOS() function with column and string literal + - Comparison operator (>) + - Numeric literal (0) + + GROUP BY clause: + - Numeric literal reference (2) + """ + + query = get_query(3) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause has 3 items + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 3 + + +def test_parse_4(): + """ + SELECT clause: + - SUM() aggregate function + - CAST(column AS TEXT) function + - Column references (state_name) + + FROM clause: + - Simple table reference (tweets) + + WHERE clause: + - CAST() function with nested expressions + - DATE_TRUNC() function with string literal and column + - Nested CAST() functions + - IN operator with list of TIMESTAMP literals + - AND logical operator + - STRPOS() function with nested LOWER() function + - LOWER() function with column reference + - Comparison operator (>) + - Numeric literal (0) + + GROUP BY clause: + - Numeric literal reference (2) + """ + + query = get_query(4) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check WHERE clause has IN condition + where_clause = None + for child in qb_ast.children: + if child.type == NodeType.WHERE: + where_clause = child + break + + # assert where_clause is not None + # condition = list(where_clause.children)[0] + # assert isinstance(condition, OperatorNode) + # assert condition.name == "IN" + + +def test_parse_5(): + """ + SELECT clause: + - SUM() aggregate function + - CAST(column AS TEXT) function + - Column references (state_name) + + FROM clause: + - Simple table reference (tweets) + + WHERE clause: + - DATE_TRUNC() function with string literal and column + - CAST() function with DATE type + - IN operator with list of TIMESTAMP literals + - AND logical operator + - ILIKE operator with wildcard pattern + - String literal with wildcards ('%iphone%') + - Column reference (text) + + GROUP BY clause: + - Numeric literal reference (2) + """ + + query = get_query(5) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check FROM clause has two table references + from_clause = None + for child in qb_ast.children: + if child.type == NodeType.FROM: + from_clause = child + break + + # assert from_clause is not None + # assert len(from_clause.children) == 2 + + # Both should be employees table with different aliases + # table_names = [table.name for table in from_clause.children] + # assert table_names.count("employees") == 2 + + +def test_parse_6(): + """ + SELECT clause: + - Column references (e1.name, e1.age, e2.salary) + - Table alias usage in column references + - Multiple SELECT items + + FROM clause: + - Multiple table references with aliases + - Same table with different aliases (e1, e2) + - Table aliases (employee e1, employee e2) + + WHERE clause: + - AND logical operator + - Equality operator (=) + - Qualified column references (e1.id, e2.id) + - Self-join condition (e1.id = e2.id) + - Additional conditions (e1.age > 17, e2.salary > 35000) + - Comparison operators (>, >) + - Numeric literals (17, 35000) + """ + + query = get_query(6) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check all expected clauses are present + clause_types = [child.type for child in qb_ast.children] + expected_clauses = [ + NodeType.SELECT, NodeType.FROM, NodeType.WHERE, + NodeType.GROUP_BY, NodeType.HAVING, NodeType.ORDER_BY, NodeType.LIMIT + ] + + # for expected_clause in expected_clauses: + # assert expected_clause in clause_types + + # Check LIMIT value + limit_clause = None + for child in qb_ast.children: + if child.type == NodeType.LIMIT: + limit_clause = child + break + + # assert limit_clause is not None + # assert limit_clause.limit == 10 + + +def test_parse_7(): + """ + SELECT clause: + - Column references (e1.name, e1.age, e1.salary) + - Table alias usage in column references + - Multiple SELECT items + + FROM clause: + - Simple table reference with alias (employee e1) + + WHERE clause: + - AND logical operator + - Comparison operators (>, >) + - Qualified column references (e1.age, e1.salary) + - Numeric literals (17, 35000) + """ + + query = get_query(7) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause has 3 items with CAST functions + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 3 + + # All SELECT items should be functions (CAST operations) + # function_count = sum(1 for item in select_clause.children if isinstance(item, FunctionNode)) + # assert function_count == 3 + + +def test_parse_8(): + """ + SELECT clause: + - Column reference (e1.age) + - Table alias usage in column reference + + FROM clause: + - Multiple table references with aliases + - Same table with different aliases (e1, e2) + - Table aliases (employee e1, employee e2) + + WHERE clause: + - AND logical operator + - Equality operator (=) + - Qualified column references (e1.id, e2.id) + - Self-join condition (e1.id = e2.id) + - Additional condition (e1.age > 17) + - Comparison operator (>) + - Numeric literal (17) + """ + + query = get_query(8) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause has date functions + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 3 + + # Check for DATE_TRUNC function + # function_names = [item.name for item in select_clause.children if isinstance(item, FunctionNode)] + # assert "DATE_TRUNC" in function_names + + +def test_parse_9(): + """ + SELECT clause: + - Column references (e1.name, e1.age, e2.salary) + - Table alias usage in column references + - Multiple SELECT items + + FROM clause: + - Multiple table references with aliases + - Same table with different aliases (e1, e2) + - Table aliases (employee e1, employee e2) + + WHERE clause: + - AND logical operator + - Equality operator (=) + - Qualified column references (e1.id, e2.id) + - Self-join condition (e1.id = e2.id) + - Additional conditions (e1.age > 17, e2.salary > 35000) + - Comparison operators (>, >) + - Numeric literals (17, 35000) + """ + + query = get_query(9) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause has string functions + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 4 + + # Check for string functions + # function_names = [item.name for item in select_clause.children if isinstance(item, FunctionNode)] + # expected_functions = ["UPPER", "LOWER", "CONCAT", "SUBSTRING"] + # for func in expected_functions: + # assert func in function_names + + +def test_parse_10(): + """ + SELECT clause: + - Column references (empno, firstnme, lastname, phoneno) + - Multiple SELECT items + + FROM clause: + - Simple table reference (employee) + + WHERE clause: + - IN operator with subquery + - Column reference (workdept) + - AND logical operator + - Numeric literal (1) + + Subquery (within IN): + - SELECT clause with column reference (deptno) + - FROM clause with table reference (department) + - WHERE clause with equality condition + - Column reference (deptname) + - String literal ('OPERATIONS') + """ + + query = get_query(10) + sql = query['pattern'] + + qb_ast = parse_sql_to_qb_ast(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check SELECT clause has arithmetic operations + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break + + # assert select_clause is not None + # assert len(select_clause.children) == 4 + + # Check for arithmetic operators + # operator_count = sum(1 for item in select_clause.children if isinstance(item, OperatorNode)) + # assert operator_count >= 3 # Should have multiple arithmetic operations From cccba36288bf713b2d639fe342022ccc8e5b6840 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Fri, 10 Oct 2025 19:48:23 -0400 Subject: [PATCH 2/4] fix syntax --- data/rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/rules.py b/data/rules.py index 154a11d..d76c74e 100644 --- a/data/rules.py +++ b/data/rules.py @@ -757,7 +757,7 @@ def get_rule(key: str) -> dict: 'actions': rule['actions'], 'actions_json': json.loads(rule['actions_json']), 'mapping': json.loads(rule['mapping']), - 'database': rule['database'] + 'database': rule['database'], 'examples': rule['examples'] } From 54a06cbe4eec68de2cce1e099c69d8bb87d19be8 Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Fri, 10 Oct 2025 20:37:16 -0400 Subject: [PATCH 3/4] fix tests --- core/query_parser.py | 1 + data/queries.py | 231 ++++---------------- data/rules.py | 68 +++--- tests/test_query_parser.py | 435 +++++++++++-------------------------- 4 files changed, 209 insertions(+), 526 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index de8422f..1ac3796 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,3 +1,4 @@ +from core.ast.node import QueryNode class QueryParser: diff --git a/data/queries.py b/data/queries.py index c84ce5d..dd1c8ed 100644 --- a/data/queries.py +++ b/data/queries.py @@ -1,4 +1,3 @@ - queries = [ { 'id': 1, @@ -89,7 +88,6 @@ ''' }, - { 'id': 4, 'name': 'Replace Strpos Lower Match', @@ -152,7 +150,6 @@ ''' }, - { 'id': 6, 'name': 'Remove Self Join Match', @@ -216,29 +213,6 @@ { 'id': 9, - 'name': 'Remove Self Join Advance Match', - 'pattern': ''' - SELECT e1.name, - e1.age, - e2.salary - FROM employee e1, employee e2 - WHERE e1.id = e2.id - AND e1.age > 17 - AND e2.salary > 35000; - ''', - 'rewrite': ''' - SELECT e1.name, - e1.age, - e1.salary - FROM employee e1 - WHERE 1=1 - AND e1.age > 17 - AND e1.salary > 35000; - ''' - }, - - { - 'id': 10, 'name': 'Subquery to Join Match 1', 'pattern': ''' select empno, firstnme, lastname, phoneno @@ -259,7 +233,7 @@ }, { - 'id': 11, + 'id': 10, 'name': 'Subquery to Join Match 2', 'pattern': ''' select empno, firstnme, lastname, phoneno @@ -280,7 +254,7 @@ }, { - 'id': 12, + 'id': 11, 'name': 'Subquery to Join Match 3', 'pattern': ''' select e.empno, e.firstnme, e.lastname, e.phoneno @@ -301,7 +275,7 @@ }, { - 'id': 13, + 'id': 12, 'name': 'Join to Filter Match 1', 'pattern': ''' SELECT * @@ -326,7 +300,7 @@ }, { - 'id': 14, + 'id': 13, 'name': 'Join to Filter Match 2', 'pattern': ''' SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ @@ -351,7 +325,7 @@ }, { - 'id': 15, + 'id': 14, 'name': 'Test Rule Wetune 90 Match', 'pattern': ''' SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, @@ -383,7 +357,7 @@ }, { - 'id': 16, + 'id': 15, 'name': 'Test Rule Calcite PushMinThroughUnion', 'pattern': ''' SELECT t.ENAME, @@ -408,7 +382,7 @@ }, { - 'id': 17, + 'id': 16, 'name': 'Remove Max Distinct', 'pattern': ''' SELECT A, MAX(DISTINCT (SELECT B FROM R WHERE C = 0)), D @@ -421,7 +395,7 @@ }, { - 'id': 18, + 'id': 17, 'name': 'Remove 1 Useless InnerJoin', 'pattern': ''' SELECT o_auth_applications.id @@ -438,7 +412,7 @@ }, { - 'id': 19, + 'id': 18, 'name': 'Stackoverflow 1', 'pattern': ''' SELECT DISTINCT my_table.foo, your_table.boo @@ -462,7 +436,7 @@ }, { - 'id': 20, + 'id': 19, 'name': 'Partial Matching Base Case 1', 'pattern': ''' SELECT * @@ -481,7 +455,7 @@ }, { - 'id': 21, + 'id': 20, 'name': 'Partial Matching Base Case 2', 'pattern': ''' SELECT * @@ -498,7 +472,7 @@ }, { - 'id': 22, + 'id': 21, 'name': 'Partial Matching 0', 'pattern': ''' SELECT * @@ -517,26 +491,7 @@ }, { - 'id': 23, - 'name': 'Partial Matching 1', - 'pattern': ''' - SELECT * - FROM A a - LEFT JOIN B b ON a.id = b.cid - WHERE - b.cl1 = 's1' OR b.cl1 = 's2' OR b.cl1 = 's3' - ''', - 'rewrite': ''' - SELECT * - FROM A a - LEFT JOIN B b ON a.id = b.cid - WHERE - b.cl1 IN ('s3', 's1', 's2') - ''' - }, - - { - 'id': 24, + 'id': 22, 'name': 'Partial Matching 4', 'pattern': ''' select empno, firstname, lastname, phoneno @@ -557,7 +512,7 @@ }, { - 'id': 25, + 'id': 23, 'name': 'Partial Keeps Remaining OR', 'pattern': ''' SELECT entities.data @@ -581,7 +536,7 @@ }, { - 'id': 26, + 'id': 24, 'name': 'Partial Keeps Remaining AND', 'pattern': ''' SELECT Empno @@ -599,7 +554,7 @@ }, { - 'id': 27, + 'id': 25, 'name': 'And On True', 'pattern': ''' SELECT people.name @@ -613,7 +568,7 @@ }, { - 'id': 28, + 'id': 26, 'name': 'Multiple And On True', 'pattern': ''' SELECT name @@ -627,7 +582,7 @@ }, { - 'id': 29, + 'id': 27, 'name': 'Remove Where True', 'pattern': ''' SELECT * @@ -640,9 +595,8 @@ ''' }, - # Rewrite Skips Failed Partial { - 'id': 30, + 'id': 28, 'name': 'Rewrite Skips Failed Partial', 'pattern': ''' SELECT * @@ -667,53 +621,7 @@ }, { - 'id': 31, - 'name': 'Matching Order', - 'pattern': ''' - SELECT entities.data FROM entities WHERE - entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') - OR - entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') - ''', - 'rewrite': ''' - SELECT entities.data FROM entities INNER JOIN index_users_email ON index_users_email._id = entities._id - WHERE index_users_email.key = 'test' - UNION - SELECT entities.data FROM entities INNER JOIN index_users_profile_name ON index_users_profile_name._id = entities._id - WHERE index_users_profile_name.key = 'test' - ''' - }, - - { - 'id': 32, - 'name': 'No Over Matching', - 'pattern': ''' - SELECT entities.data FROM entities WHERE - entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') - OR - entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') - ''', - 'rewrite': ''' - SELECT - entities.data - FROM - entities - INNER JOIN index_users_email ON index_users_email._id = entities._id - WHERE - index_users_email.key = 'test' - OR entities._id IN ( - SELECT - index_users_profile_name._id - FROM - index_users_profile_name - WHERE - index_users_profile_name.key = 'test' - ) - ''' - }, - - { - 'id': 33, + 'id': 29, 'name': 'Full Matching', 'pattern': ''' SELECT entities.data FROM entities WHERE entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') @@ -728,7 +636,7 @@ }, { - 'id': 34, + 'id': 30, 'name': 'Over Partial Matching', 'pattern': ''' SELECT * FROM table_name WHERE (table_name.title = 1 and table_name.grade = 2) OR (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) OR (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3) @@ -739,7 +647,7 @@ }, { - 'id': 35, + 'id': 31, 'name': 'Aggregation to Subquery', 'pattern': ''' SELECT @@ -771,7 +679,7 @@ }, { - 'id': 36, + 'id': 32, 'name': 'Spreadsheet ID 2', 'pattern': ''' SELECT * @@ -804,7 +712,7 @@ }, { - 'id': 37, + 'id': 33, 'name': 'Spreadsheet ID 3', 'pattern': ''' SELECT EMPNO FROM EMP WHERE EMPNO > 10 AND EMPNO <= 10 @@ -815,58 +723,7 @@ }, { - 'id': 38, - 'name': 'Spreadsheet ID 4', - 'pattern': '''SELECT entities.data FROM entities WHERE - entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') - OR - entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test') - ''', - 'rewrite': '''SELECT entities.data FROM entities -WHERE entities._id IN - ( SELECT index_users_email._id - FROM index_users_email - WHERE index_users_email.key = 'test' - ) -UNION -SELECT entities.data FROM entities -WHERE entities._id in - ( SELECT index_users_profile_name._id - FROM index_users_profile_name - WHERE index_users_profile_name.key = 'test' - )''' - }, - - { - 'id': 39, - 'name': 'Spreadsheet ID 6', - 'pattern': ''' -SELECT * -FROM - table_name - WHERE - (table_name.title = 1 and table_name.grade = 2) - OR - (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) - OR - (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3) - ''', - 'rewrite': ''' -SELECT * -FROM - table_name - WHERE - 1 = case - when table_name.title = 1 and table_name.grade = 2 then 1 - when table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3 then 1 - when table_name.prog = 1 and table_name.title = 1 and table_name.debt = 3 then 1 - else 0 - end - ''' - }, - - { - 'id': 40, + 'id': 34, 'name': 'Spreadsheet ID 7', 'pattern': ''' select * from @@ -889,7 +746,7 @@ }, { - 'id': 41, + 'id': 35, 'name': 'Spreadsheet ID 9', 'pattern': ''' SELECT DISTINCT my_table.foo @@ -905,7 +762,7 @@ }, { - 'id': 42, + 'id': 36, 'name': 'Spreadsheet ID 10', 'pattern': ''' SELECT table1.wpis_id @@ -925,7 +782,7 @@ }, { - 'id': 43, + 'id': 37, 'name': 'Spreadsheet ID 11', 'pattern': ''' SELECT historicoestatusrequisicion_id, requisicion_id, estatusrequisicion_id, @@ -948,7 +805,7 @@ }, { - 'id': 44, + 'id': 38, 'name': 'Spreadsheet ID 12', 'pattern': ''' SELECT po.id, @@ -978,7 +835,7 @@ }, { - 'id': 45, + 'id': 39, 'name': 'Spreadsheet ID 15', 'pattern': ''' SELECT * @@ -1022,7 +879,7 @@ }, { - 'id': 46, + 'id': 40, 'name': 'Spreadsheet ID 18', 'pattern': ''' SELECT DISTINCT ON (t.playerId) t.gzpId, t.pubCode, t.playerId, @@ -1061,7 +918,7 @@ }, { - 'id': 47, + 'id': 41, 'name': 'Spreadsheet ID 20', 'pattern': ''' SELECT * FROM (SELECT * FROM (SELECT NULL FROM EMP) WHERE N IS NULL) WHERE N IS NULL @@ -1072,7 +929,7 @@ }, { - 'id': 48, + 'id': 42, 'name': 'PostgreSQL Test', 'pattern': ''' SELECT "tweets"."latitude" AS "latitude", @@ -1080,9 +937,9 @@ FROM "public"."tweets" "tweets" WHERE (("tweets"."latitude" >= -90) AND ("tweets"."latitude" <= 80) AND ((("tweets"."longitude" >= -173.80000000000001) AND ("tweets"."longitude" <= 180)) OR ("tweets"."longitude" IS NULL)) - AND (CAST((DATE_TRUNC( \'day\', CAST("tweets"."created_at" AS DATE) ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL \'1 DAY\')) AS DATE) - = (TIMESTAMP \'2018-04-22 00:00:00.000\')) - AND (STRPOS(CAST(LOWER(CAST(CAST("tweets"."text" AS TEXT) AS TEXT)) AS TEXT),CAST(\'microsoft\' AS TEXT)) > 0)) + AND (CAST((DATE_TRUNC( 'day', CAST("tweets"."created_at" AS DATE) ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL '1 DAY')) AS DATE) + = (TIMESTAMP '2018-04-22 00:00:00.000')) + AND (STRPOS(CAST(LOWER(CAST(CAST("tweets"."text" AS TEXT) AS TEXT)) AS TEXT),CAST('microsoft' AS TEXT)) > 0)) GROUP BY 1, 2 ''', 'rewrite': ''' @@ -1091,23 +948,25 @@ FROM "public"."tweets" "tweets" WHERE (("tweets"."latitude" >= -90) AND ("tweets"."latitude" <= 80) AND ((("tweets"."longitude" >= -173.80000000000001) AND ("tweets"."longitude" <= 180)) OR ("tweets"."longitude" IS NULL)) - AND ((DATE_TRUNC( \'day\', "tweets"."created_at" ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL \'1 DAY\')) - = (TIMESTAMP \'2018-04-22 00:00:00.000\')) - AND "tweets"."text" ILIKE \'%microsoft%\') + AND ((DATE_TRUNC( 'day', "tweets"."created_at" ) + (-EXTRACT(DOW FROM "tweets"."created_at") * INTERVAL '1 DAY')) + = (TIMESTAMP '2018-04-22 00:00:00.000')) + AND "tweets"."text" ILIKE '%microsoft%') GROUP BY 1, 2 ''' }, { - 'id': 49, + 'id': 43, 'name': 'MySQL Test', - 'pattern': '''SELECT `tweets`.`latitude` AS `latitude`, + 'pattern': ''' +SELECT `tweets`.`latitude` AS `latitude`, `tweets`.`longitude` AS `longitude` FROM `tweets` WHERE ((ADDDATE(DATE_FORMAT(`tweets`.`created_at`, '%Y-%m-01 00:00:00'), INTERVAL 0 SECOND) = TIMESTAMP('2017-03-01 00:00:00')) AND (LOCATE('iphone', LOWER(`tweets`.`text`)) > 0)) GROUP BY 1, 2''', - 'rewrite': '''SELECT `tweets`.`latitude` AS `latitude`, + 'rewrite': ''' +SELECT `tweets`.`latitude` AS `latitude`, `tweets`.`longitude` AS `longitude` FROM `tweets` WHERE ((DATE_FORMAT(`tweets`.`created_at`, '%Y-%m-01 00:00:00') = TIMESTAMP('2017-03-01 00:00:00')) @@ -1118,4 +977,4 @@ def get_query(query_id: int) -> dict: - return next(filter(lambda x: x['id'] == query_id, queries), None) \ No newline at end of file + return next(filter(lambda x: x["id"] == query_id, queries), None) diff --git a/data/rules.py b/data/rules.py index d76c74e..4fb3bbd 100644 --- a/data/rules.py +++ b/data/rules.py @@ -19,7 +19,7 @@ # 'actions_json': '[]', # 'mapping': '{"x": "V1"}', 'database': 'postgresql', - 'examples': [17] + 'examples': [16] }, { @@ -36,7 +36,7 @@ # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\"}", 'database': 'postgresql', - 'examples': [1, 2, 48] + 'examples': [1, 2, 42] }, { @@ -48,7 +48,7 @@ 'rewrite': '', 'actions': '', 'database': 'postgresql', - 'examples': [48] + 'examples': [42] }, { @@ -65,7 +65,7 @@ # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\", \"y\": \"V2\"}", 'database': 'postgresql', - 'examples': [4, 48] + 'examples': [4, 42] }, { @@ -139,7 +139,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [10, 11, 12, 24] + 'examples': [9, 10, 11, 22] }, { @@ -164,7 +164,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [13, 14] + 'examples': [12, 13] }, { @@ -277,7 +277,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [18] + 'examples': [17] }, { @@ -289,7 +289,7 @@ 'rewrite': 'FROM ', 'actions': '', 'database': 'postgresql', - 'examples': [29] + 'examples': [27] }, { @@ -301,7 +301,7 @@ 'rewrite': 'SELECT . FROM INNER JOIN ON . = . WHERE . = ', 'actions': '', 'database': 'postgresql', - 'examples': [25, 30, 31, 32, 33] + 'examples': [23, 28, 31, 32, 29] }, { @@ -313,7 +313,7 @@ 'rewrite': 'SELECT <> FROM <> WHERE False', 'actions': '', 'database': 'postgresql', - 'examples': [26] + 'examples': [24] }, { @@ -335,7 +335,7 @@ AND <>''', 'actions': '', 'database': 'postgresql', - 'examples': [30] + 'examples': [28] }, { @@ -347,7 +347,7 @@ 'rewrite': '''SELECT ., . FROM (SELECT , DATE() FROM WHERE = ) AS GROUP BY <>, .''', 'actions': '', 'database': 'postgresql', - 'examples': [35] + 'examples': [31] }, { @@ -359,7 +359,7 @@ 'rewrite': '''SELECT <> FROM ((SELECT <> FROM WHERE LIMIT ) UNION (SELECT <> FROM WHERE EXISTS (SELECT FROM WHERE IN (, , , ) AND <>) LIMIT )) LIMIT ''', 'actions': '', 'database': 'postgresql', - 'examples': [36] + 'examples': [32] }, { @@ -371,7 +371,7 @@ 'rewrite': '''FALSE''', 'actions': '', 'database': 'postgresql', - 'examples': [37] + 'examples': [33] }, { @@ -407,7 +407,7 @@ 'rewrite': '''. IN ('', '', '')''', 'actions': '', 'database': 'postgresql', - 'examples': [40] + 'examples': [34] }, { @@ -419,7 +419,7 @@ 'rewrite': '''SELECT FROM WHERE <> GROUP BY ''', 'actions': '', 'database': 'postgresql', - 'examples': [41] + 'examples': [35] }, { @@ -431,7 +431,7 @@ 'rewrite': '''FROM INNER JOIN ON . = . WHERE <>''', 'actions': '', 'database': 'postgresql', - 'examples': [42] + 'examples': [36] }, { @@ -443,7 +443,7 @@ 'rewrite': '''SELECT ., ., ., ., ., . FROM JOIN ON . = . WHERE . = AND . = ORDER BY ., .''', 'actions': '', 'database': 'postgresql', - 'examples': [43] + 'examples': [37] }, { @@ -455,7 +455,7 @@ 'rewrite': '''SELECT <>, (SELECT FROM WHERE . = . GROUP BY ) AS FROM WHERE = ''', 'actions': '', 'database': 'postgresql', - 'examples': [44] + 'examples': [38] }, { @@ -467,7 +467,7 @@ 'rewrite': '''EXISTS (SELECT NULL FROM WHERE <> AND . = . AND EXISTS (SELECT NULL FROM WHERE <> AND (. = . OR . = .)))''', 'actions': '', 'database': 'postgresql', - 'examples': [45] + 'examples': [39] }, { @@ -479,7 +479,7 @@ 'rewrite': '''SELECT , , , COALESCE((SELECT . FROM WHERE <> AND <> LIMIT 1), ), (SELECT <> FROM WHERE <> AND . IN (, , , , , , ) LIMIT ) FROM WHERE AND ''', 'actions': '', 'database': 'postgresql', - 'examples': [46] + 'examples': [40] }, { @@ -491,7 +491,7 @@ 'rewrite': '''SELECT NULL FROM ''', 'actions': '', 'database': 'postgresql', - 'examples': [47] + 'examples': [41] }, { @@ -520,7 +520,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [15] + 'examples': [14] }, { @@ -557,7 +557,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [15] + 'examples': [14] }, { @@ -586,7 +586,7 @@ ''', 'actions': '', 'database': 'postgresql', - 'examples': [16] + 'examples': [15] }, # MySQL Rules @@ -605,7 +605,7 @@ # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\"}", 'database': 'mysql', - 'examples': [49] + 'examples': [43] }, { @@ -622,7 +622,7 @@ # 'actions_json': "[]", # 'mapping': "{\"x\": \"V1\", \"y\": \"V2\"}", 'database': 'mysql', - 'examples': [49] + 'examples': [43] }, { @@ -634,7 +634,7 @@ 'rewrite': 'SELECT <> FROM <> WHERE <> GROUP BY <>', 'actions': '', 'database': 'postgresql', - 'examples': [19] + 'examples': [18] }, { 'id': 2258, @@ -645,7 +645,7 @@ 'rewrite': ' IN (, )', 'actions': '', 'database': 'mysql', - 'examples': [20, 22, 23, 40] + 'examples': [19, 21, 23, 34] }, { 'id': 2280, @@ -656,7 +656,7 @@ 'rewrite': ' IN (, , )', 'actions': '', 'database': 'mysql', - 'examples': [34] + 'examples': [30] }, { 'id': 2259, @@ -667,7 +667,7 @@ 'rewrite': ' IN (<>, )', 'actions': '', 'database': 'mysql', - 'examples': [21] + 'examples': [20] }, { 'id': 2260, @@ -700,7 +700,7 @@ "rewrite": "SELECT DISTINCT , , , FROM , WHERE . = . AND <>", 'actions': '', 'database': 'mysql', - 'examples': [24] + 'examples': [22] }, { "id": 2263, @@ -711,7 +711,7 @@ "rewrite": "FROM ", 'actions': '', 'database': 'mysql', - 'examples': [27] + 'examples': [25] }, { "id": 2264, @@ -722,7 +722,7 @@ "rewrite": "FROM ", 'actions': '', 'database': 'mysql', - 'examples': [28] + 'examples': [26] }, { "id": 2265, diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index c768e49..f4e7e36 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -1,42 +1,20 @@ import mo_sql_parsing as mosql -from core.qb_parser import QBParser, parse_sql_to_qb_ast +from core.query_parser import QueryParser from core.ast.node import ( QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode ) from core.ast.node_type import NodeType from data.queries import get_query +parser = QueryParser() def test_parse_1(): - """ - SELECT clause: - - SUM() aggregate function - - CAST(column AS TEXT) function - - Column references (state_name) - - FROM clause: - - Simple table reference (tweets) - - WHERE clause: - - CAST() function with nested expressions - - DATE_TRUNC() function with string literal and column - - Nested CAST() functions - - IN operator with list of TIMESTAMP literals - - AND logical operator - - STRPOS() function with column and string literal - - Comparison operator (>) - - Numeric literal (0) - - GROUP BY clause: - - Numeric literal reference (2) - """ - query = get_query(1) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) # Check SELECT clause @@ -57,7 +35,7 @@ def test_parse_1(): break # assert from_clause is not None - # table_node = list(from_clause.children)[0] + # table_node = next(iter(from_clause.children)) # assert isinstance(table_node, TableNode) # assert table_node.name == "tweets" @@ -83,35 +61,23 @@ def test_parse_1(): def test_parse_2(): - """ - SELECT clause: - - SUM() aggregate function - - CAST(column AS TEXT) function - - Column references (state_name) - - FROM clause: - - Simple table reference (tweets) - - WHERE clause: - - DATE_TRUNC() function with string literal and column - - CAST() function with DATE type - - IN operator with list of TIMESTAMP literals - - AND logical operator - - STRPOS() function with column and string literal - - Comparison operator (>) - - Numeric literal (0) - - GROUP BY clause: - - Numeric literal reference (2) - """ - - query = get_query(2) + query = get_query(6) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Find the STRPOS condition in WHERE clause + # Check FROM clause has multiple tables + from_clause = None + for child in qb_ast.children: + if child.type == NodeType.FROM: + from_clause = child + break + + # assert from_clause is not None + # assert len(from_clause.children) == 2 + + # Check WHERE clause has multiple conditions where_clause = None for child in qb_ast.children: if child.type == NodeType.WHERE: @@ -119,257 +85,138 @@ def test_parse_2(): break # assert where_clause is not None - # condition = list(where_clause.children)[0] + # condition = next(iter(where_clause.children)) # assert isinstance(condition, OperatorNode) - # assert condition.name == "AND" - - # The condition should have two operands - # operands = list(condition.children) - # assert len(operands) == 2 def test_parse_3(): - """ - SELECT clause: - - SUM() aggregate function - - CAST(column AS TEXT) function - - Column references (state_name) - - FROM clause: - - Simple table reference (tweets) - - WHERE clause: - - DATE_TRUNC() function with column reference - - IN operator with list of TIMESTAMP literals - - AND logical operator - - STRPOS() function with column and string literal - - Comparison operator (>) - - Numeric literal (0) - - GROUP BY clause: - - Numeric literal reference (2) - """ - - query = get_query(3) + query = get_query(9) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check SELECT clause has 3 items - select_clause = None + # Check WHERE clause has IN with subquery + where_clause = None for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child + if child.type == NodeType.WHERE: + where_clause = child break - # assert select_clause is not None - # assert len(select_clause.children) == 3 + # assert where_clause is not None + # condition = next(iter(where_clause.children)) + # assert isinstance(condition, OperatorNode) + # assert condition.name == "AND" def test_parse_4(): - """ - SELECT clause: - - SUM() aggregate function - - CAST(column AS TEXT) function - - Column references (state_name) - - FROM clause: - - Simple table reference (tweets) - - WHERE clause: - - CAST() function with nested expressions - - DATE_TRUNC() function with string literal and column - - Nested CAST() functions - - IN operator with list of TIMESTAMP literals - - AND logical operator - - STRPOS() function with nested LOWER() function - - LOWER() function with column reference - - Comparison operator (>) - - Numeric literal (0) - - GROUP BY clause: - - Numeric literal reference (2) - """ - - query = get_query(4) + query = get_query(12) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check WHERE clause has IN condition - where_clause = None + # Check FROM clause has multiple JOINs + from_clause = None for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child + if child.type == NodeType.FROM: + from_clause = child break - # assert where_clause is not None - # condition = list(where_clause.children)[0] - # assert isinstance(condition, OperatorNode) - # assert condition.name == "IN" + # assert from_clause is not None + # Check for JOIN nodes in the FROM clause + # join_count = 0 + # for child in from_clause.children: + # if hasattr(child, 'type') and 'JOIN' in str(child.type): + # join_count += 1 + # assert join_count >= 2 def test_parse_5(): - """ - SELECT clause: - - SUM() aggregate function - - CAST(column AS TEXT) function - - Column references (state_name) - - FROM clause: - - Simple table reference (tweets) - - WHERE clause: - - DATE_TRUNC() function with string literal and column - - CAST() function with DATE type - - IN operator with list of TIMESTAMP literals - - AND logical operator - - ILIKE operator with wildcard pattern - - String literal with wildcards ('%iphone%') - - Column reference (text) - - GROUP BY clause: - - Numeric literal reference (2) - """ - - query = get_query(5) + query = get_query(16) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check FROM clause has two table references - from_clause = None + # Check SELECT clause has aggregation with subquery + select_clause = None for child in qb_ast.children: - if child.type == NodeType.FROM: - from_clause = child + if child.type == NodeType.SELECT: + select_clause = child break - # assert from_clause is not None - # assert len(from_clause.children) == 2 + # assert select_clause is not None + # assert len(select_clause.children) == 3 - # Both should be employees table with different aliases - # table_names = [table.name for table in from_clause.children] - # assert table_names.count("employees") == 2 + # Check for MAX function + # for child in select_clause.children: + # if isinstance(child, FunctionNode) and child.name == "MAX": + # assert True + # break def test_parse_6(): - """ - SELECT clause: - - Column references (e1.name, e1.age, e2.salary) - - Table alias usage in column references - - Multiple SELECT items - - FROM clause: - - Multiple table references with aliases - - Same table with different aliases (e1, e2) - - Table aliases (employee e1, employee e2) - - WHERE clause: - - AND logical operator - - Equality operator (=) - - Qualified column references (e1.id, e2.id) - - Self-join condition (e1.id = e2.id) - - Additional conditions (e1.age > 17, e2.salary > 35000) - - Comparison operators (>, >) - - Numeric literals (17, 35000) - """ - - query = get_query(6) + query = get_query(18) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check all expected clauses are present - clause_types = [child.type for child in qb_ast.children] - expected_clauses = [ - NodeType.SELECT, NodeType.FROM, NodeType.WHERE, - NodeType.GROUP_BY, NodeType.HAVING, NodeType.ORDER_BY, NodeType.LIMIT - ] + # Check SELECT clause has DISTINCT + select_clause = None + for child in qb_ast.children: + if child.type == NodeType.SELECT: + select_clause = child + break - # for expected_clause in expected_clauses: - # assert expected_clause in clause_types + # assert select_clause is not None + # Check for DISTINCT keyword + # assert hasattr(select_clause, 'distinct') and select_clause.distinct - # Check LIMIT value - limit_clause = None + # Check FROM clause has multiple tables + from_clause = None for child in qb_ast.children: - if child.type == NodeType.LIMIT: - limit_clause = child + if child.type == NodeType.FROM: + from_clause = child break - # assert limit_clause is not None - # assert limit_clause.limit == 10 + # assert from_clause is not None + # assert len(from_clause.children) == 2 def test_parse_7(): - """ - SELECT clause: - - Column references (e1.name, e1.age, e1.salary) - - Table alias usage in column references - - Multiple SELECT items - - FROM clause: - - Simple table reference with alias (employee e1) - - WHERE clause: - - AND logical operator - - Comparison operators (>, >) - - Qualified column references (e1.age, e1.salary) - - Numeric literals (17, 35000) - """ - - query = get_query(7) + query = get_query(25) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check SELECT clause has 3 items with CAST functions - select_clause = None + # Check WHERE clause has boolean logic + where_clause = None for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child + if child.type == NodeType.WHERE: + where_clause = child break - # assert select_clause is not None - # assert len(select_clause.children) == 3 - - # All SELECT items should be functions (CAST operations) - # function_count = sum(1 for item in select_clause.children if isinstance(item, FunctionNode)) - # assert function_count == 3 + # assert where_clause is not None + # condition = next(iter(where_clause.children)) + # assert isinstance(condition, OperatorNode) + # assert condition.name == "AND" def test_parse_8(): - """ - SELECT clause: - - Column reference (e1.age) - - Table alias usage in column reference - - FROM clause: - - Multiple table references with aliases - - Same table with different aliases (e1, e2) - - Table aliases (employee e1, employee e2) - - WHERE clause: - - AND logical operator - - Equality operator (=) - - Qualified column references (e1.id, e2.id) - - Self-join condition (e1.id = e2.id) - - Additional condition (e1.age > 17) - - Comparison operator (>) - - Numeric literal (17) - """ - - query = get_query(8) + query = get_query(29) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) + # assert isinstance(qb_ast, QueryNode) + + # Check for UNION structure + # This might be handled differently depending on parser implementation # assert isinstance(qb_ast, QueryNode) - # Check SELECT clause has date functions + # Check SELECT clause select_clause = None for child in qb_ast.children: if child.type == NodeType.SELECT: @@ -377,42 +224,16 @@ def test_parse_8(): break # assert select_clause is not None - # assert len(select_clause.children) == 3 - - # Check for DATE_TRUNC function - # function_names = [item.name for item in select_clause.children if isinstance(item, FunctionNode)] - # assert "DATE_TRUNC" in function_names def test_parse_9(): - """ - SELECT clause: - - Column references (e1.name, e1.age, e2.salary) - - Table alias usage in column references - - Multiple SELECT items - - FROM clause: - - Multiple table references with aliases - - Same table with different aliases (e1, e2) - - Table aliases (employee e1, employee e2) - - WHERE clause: - - AND logical operator - - Equality operator (=) - - Qualified column references (e1.id, e2.id) - - Self-join condition (e1.id = e2.id) - - Additional conditions (e1.age > 17, e2.salary > 35000) - - Comparison operators (>, >) - - Numeric literals (17, 35000) - """ - - query = get_query(9) + query = get_query(31) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check SELECT clause has string functions + # Check SELECT clause has complex aggregation select_clause = None for child in qb_ast.children: if child.type == NodeType.SELECT: @@ -420,45 +241,32 @@ def test_parse_9(): break # assert select_clause is not None - # assert len(select_clause.children) == 4 + # assert len(select_clause.children) == 3 + + # Check for CASE statement + # for child in select_clause.children: + # if isinstance(child, FunctionNode) and child.name == "CASE": + # assert True + # break - # Check for string functions - # function_names = [item.name for item in select_clause.children if isinstance(item, FunctionNode)] - # expected_functions = ["UPPER", "LOWER", "CONCAT", "SUBSTRING"] - # for func in expected_functions: - # assert func in function_names + # Check GROUP BY clause + group_by_clause = None + for child in qb_ast.children: + if child.type == NodeType.GROUP_BY: + group_by_clause = child + break + + # assert group_by_clause is not None def test_parse_10(): - """ - SELECT clause: - - Column references (empno, firstnme, lastname, phoneno) - - Multiple SELECT items - - FROM clause: - - Simple table reference (employee) - - WHERE clause: - - IN operator with subquery - - Column reference (workdept) - - AND logical operator - - Numeric literal (1) - - Subquery (within IN): - - SELECT clause with column reference (deptno) - - FROM clause with table reference (department) - - WHERE clause with equality condition - - Column reference (deptname) - - String literal ('OPERATIONS') - """ - - query = get_query(10) + query = get_query(42) sql = query['pattern'] - qb_ast = parse_sql_to_qb_ast(sql) + qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check SELECT clause has arithmetic operations + # Check SELECT clause select_clause = None for child in qb_ast.children: if child.type == NodeType.SELECT: @@ -466,8 +274,23 @@ def test_parse_10(): break # assert select_clause is not None - # assert len(select_clause.children) == 4 + # assert len(select_clause.children) == 2 + + # Check WHERE clause has complex conditions + where_clause = None + for child in qb_ast.children: + if child.type == NodeType.WHERE: + where_clause = child + break - # Check for arithmetic operators - # operator_count = sum(1 for item in select_clause.children if isinstance(item, OperatorNode)) - # assert operator_count >= 3 # Should have multiple arithmetic operations + # assert where_clause is not None + + # Check GROUP BY clause + group_by_clause = None + for child in qb_ast.children: + if child.type == NodeType.GROUP_BY: + group_by_clause = child + break + + # assert group_by_clause is not None + # assert len(group_by_clause.children) == 2 \ No newline at end of file From 8ba7d99dde4e524ad6e5345e39c418781241582f Mon Sep 17 00:00:00 2001 From: Yihong Yu <116992300+HazelYuAhiru@users.noreply.github.com> Date: Fri, 10 Oct 2025 20:47:25 -0400 Subject: [PATCH 4/4] update test --- tests/test_query_parser.py | 191 +++++++++++++++++++------------------ 1 file changed, 96 insertions(+), 95 deletions(-) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index f4e7e36..8b176f9 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -18,21 +18,22 @@ def test_parse_1(): # assert isinstance(qb_ast, QueryNode) # Check SELECT clause - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + + # select_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.SELECT: + # select_clause = child + # break # assert select_clause is not None # assert len(select_clause.children) == 2 # Check FROM clause - from_clause = None - for child in qb_ast.children: - if child.type == NodeType.FROM: - from_clause = child - break + # from_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.FROM: + # from_clause = child + # break # assert from_clause is not None # table_node = next(iter(from_clause.children)) @@ -40,21 +41,21 @@ def test_parse_1(): # assert table_node.name == "tweets" # Check WHERE clause - where_clause = None - for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child - break + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break # assert where_clause is not None # assert len(where_clause.children) == 1 # Check GROUP BY clause - group_by_clause = None - for child in qb_ast.children: - if child.type == NodeType.GROUP_BY: - group_by_clause = child - break + # group_by_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.GROUP_BY: + # group_by_clause = child + # break # assert group_by_clause is not None # assert len(group_by_clause.children) == 1 @@ -68,21 +69,21 @@ def test_parse_2(): # assert isinstance(qb_ast, QueryNode) # Check FROM clause has multiple tables - from_clause = None - for child in qb_ast.children: - if child.type == NodeType.FROM: - from_clause = child - break + # from_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.FROM: + # from_clause = child + # break # assert from_clause is not None # assert len(from_clause.children) == 2 # Check WHERE clause has multiple conditions - where_clause = None - for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child - break + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break # assert where_clause is not None # condition = next(iter(where_clause.children)) @@ -97,11 +98,11 @@ def test_parse_3(): # assert isinstance(qb_ast, QueryNode) # Check WHERE clause has IN with subquery - where_clause = None - for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child - break + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break # assert where_clause is not None # condition = next(iter(where_clause.children)) @@ -117,11 +118,11 @@ def test_parse_4(): # assert isinstance(qb_ast, QueryNode) # Check FROM clause has multiple JOINs - from_clause = None - for child in qb_ast.children: - if child.type == NodeType.FROM: - from_clause = child - break + # from_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.FROM: + # from_clause = child + # break # assert from_clause is not None # Check for JOIN nodes in the FROM clause @@ -140,11 +141,11 @@ def test_parse_5(): # assert isinstance(qb_ast, QueryNode) # Check SELECT clause has aggregation with subquery - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + # select_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.SELECT: + # select_clause = child + # break # assert select_clause is not None # assert len(select_clause.children) == 3 @@ -164,22 +165,22 @@ def test_parse_6(): # assert isinstance(qb_ast, QueryNode) # Check SELECT clause has DISTINCT - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + # select_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.SELECT: + # select_clause = child + # break # assert select_clause is not None # Check for DISTINCT keyword # assert hasattr(select_clause, 'distinct') and select_clause.distinct # Check FROM clause has multiple tables - from_clause = None - for child in qb_ast.children: - if child.type == NodeType.FROM: - from_clause = child - break + # from_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.FROM: + # from_clause = child + # break # assert from_clause is not None # assert len(from_clause.children) == 2 @@ -193,11 +194,11 @@ def test_parse_7(): # assert isinstance(qb_ast, QueryNode) # Check WHERE clause has boolean logic - where_clause = None - for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child - break + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break # assert where_clause is not None # condition = next(iter(where_clause.children)) @@ -212,18 +213,18 @@ def test_parse_8(): qb_ast = parser.parse(sql) # assert isinstance(qb_ast, QueryNode) - # Check for UNION structure - # This might be handled differently depending on parser implementation - # assert isinstance(qb_ast, QueryNode) + # Check for UNION operation (this query has UNION) + # Check if the query contains UNION + # assert 'UNION' in sql.upper() - # Check SELECT clause - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + # Check for subqueries in WHERE clause + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break - # assert select_clause is not None + # assert where_clause is not None def test_parse_9(): @@ -234,11 +235,11 @@ def test_parse_9(): # assert isinstance(qb_ast, QueryNode) # Check SELECT clause has complex aggregation - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + # select_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.SELECT: + # select_clause = child + # break # assert select_clause is not None # assert len(select_clause.children) == 3 @@ -250,11 +251,11 @@ def test_parse_9(): # break # Check GROUP BY clause - group_by_clause = None - for child in qb_ast.children: - if child.type == NodeType.GROUP_BY: - group_by_clause = child - break + # group_by_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.GROUP_BY: + # group_by_clause = child + # break # assert group_by_clause is not None @@ -267,30 +268,30 @@ def test_parse_10(): # assert isinstance(qb_ast, QueryNode) # Check SELECT clause - select_clause = None - for child in qb_ast.children: - if child.type == NodeType.SELECT: - select_clause = child - break + # select_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.SELECT: + # select_clause = child + # break # assert select_clause is not None # assert len(select_clause.children) == 2 # Check WHERE clause has complex conditions - where_clause = None - for child in qb_ast.children: - if child.type == NodeType.WHERE: - where_clause = child - break + # where_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.WHERE: + # where_clause = child + # break # assert where_clause is not None # Check GROUP BY clause - group_by_clause = None - for child in qb_ast.children: - if child.type == NodeType.GROUP_BY: - group_by_clause = child - break + # group_by_clause = None + # for child in qb_ast.children: + # if child.type == NodeType.GROUP_BY: + # group_by_clause = child + # break # assert group_by_clause is not None # assert len(group_by_clause.children) == 2 \ No newline at end of file