Skip to content

Commit 482dc3a

Browse files
replace sqlparse.format with sqlglot.transpile
1 parent 62dfeb1 commit 482dc3a

File tree

6 files changed

+44 -40 lines changed

6 files changed

+44 -40 lines changed

superset/common/query_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def validate(
276276
# performance for WHERE ... IN (...) clauses
277277
# Clauses are anyway checked for their validity in
278278
# e.g., connectors/sqla/models/get_query_str_extended
279-
# self._sanitize_filters()
279+
self._sanitize_filters()
280280
return None
281281
except QueryObjectValidationError as ex:
282282
if raise_exceptions:

superset/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1900,7 +1900,7 @@ class ExtraDynamicQueryFilters(TypedDict, total=False):
19001900
elif importlib.util.find_spec("superset_config") and not is_test():
19011901
try:
19021902
# pylint: disable=import-error,wildcard-import,unused-wildcard-import
1903-
import superset_config
1903+
import superset_config as superset_config
19041904
from superset_config import * # noqa: F403, F401
19051905

19061906
click.secho(

superset/db_engine_specs/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
import pandas as pd
4040
import requests
41-
import sqlparse
41+
import sqlglot
4242
from apispec import APISpec
4343
from apispec.ext.marshmallow import MarshmallowPlugin
4444
from deprecation import deprecated
@@ -1236,7 +1236,7 @@ def get_cte_query(cls, sql: str) -> str | None:
12361236
12371237
"""
12381238
if not cls.allows_cte_in_subquery:
1239-
stmt = sqlparse.parse(sql)[0]
1239+
stmt = sqlglot.tokenize(sql)
12401240

12411241
# The first meaningful token for CTE will be with WITH
12421242
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
@@ -2158,7 +2158,8 @@ def cancel_query( # pylint: disable=unused-argument
21582158

21592159
@classmethod
21602160
def parse_sql(cls, sql: str) -> list[str]:
2161-
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
2161+
return sqlglot.transpile(sql)
2162+
# return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
21622163

21632164
@classmethod
21642165
def get_impersonation_key(cls, user: User | None) -> Any:

superset/models/core.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from sqlalchemy.pool import NullPool
5959
from sqlalchemy.schema import UniqueConstraint
6060
from sqlalchemy.sql import ColumnElement, expression, Select
61-
from sqlglot import parse
6261

6362
from superset import app, db_engine_specs, is_feature_enabled
6463
from superset.commands.database.exceptions import DatabaseInvalidError
@@ -654,11 +653,7 @@ def get_df( # pylint: disable=too-many-locals
654653
schema: str | None = None,
655654
mutator: Callable[[pd.DataFrame], None] | None = None,
656655
) -> pd.DataFrame:
657-
# before we split sqls using sql parse, however this core code is only reachable
658-
# with single sql queries. Thus, we remove the engine spec parser here
659-
# sqls = self.db_engine_spec.parse_sql(sql)
660-
sqls = parse(sql)
661-
656+
sqls = self.db_engine_spec.parse_sql(sql)
662657
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
663658
engine_url = engine.url
664659

superset/sql_parse.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
from collections.abc import Iterator
2525
from typing import Any, cast, TYPE_CHECKING
2626

27+
import sqlglot
2728
import sqlparse
2829
from flask_babel import gettext as __
2930
from jinja2 import nodes
3031
from sqlalchemy import and_
3132
from sqlglot.dialects.dialect import Dialects
33+
from sqlglot.errors import ParseError
3234
from sqlparse import keywords
3335
from sqlparse.lexer import Lexer
3436
from sqlparse.sql import (
@@ -42,7 +44,6 @@
4244
Where,
4345
)
4446
from sqlparse.tokens import (
45-
Comment,
4647
CTE,
4748
DDL,
4849
DML,
@@ -257,6 +258,7 @@ def __init__(
257258
sql_statement: str,
258259
engine: str = "base",
259260
):
261+
sql_statement = sqlglot.transpile(sql_statement)
260262
self.sql: str = sql_statement
261263
self._engine = engine
262264
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
@@ -579,30 +581,35 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
579581

580582
def sanitize_clause(clause: str) -> str:
581583
# clause = sqlparse.format(clause, strip_comments=True)
582-
statements = sqlparse.parse(clause)
584+
try:
585+
statements = sqlglot.transpile(clause, pretty=True)
586+
except Exception as p_err:
587+
if isinstance(p_err, ParseError):
588+
raise QueryClauseValidationException(str(p_err)) from p_err
589+
raise ValueError(str(p_err)) from None
583590
if len(statements) != 1:
584591
raise QueryClauseValidationException("Clause contains multiple statements")
585-
open_parens = 0
586-
587-
previous_token = None
588-
for token in statements[0]:
589-
if token.value == "/" and previous_token and previous_token.value == "*":
590-
raise QueryClauseValidationException("Closing unopened multiline comment")
591-
if token.value == "*" and previous_token and previous_token.value == "/":
592-
raise QueryClauseValidationException("Unclosed multiline comment")
593-
if token.value in (")", "("):
594-
open_parens += 1 if token.value == "(" else -1
595-
if open_parens < 0:
596-
raise QueryClauseValidationException(
597-
"Closing unclosed parenthesis in filter clause"
598-
)
599-
previous_token = token
600-
if open_parens > 0:
601-
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
602-
603-
if previous_token and previous_token.ttype in Comment:
604-
if previous_token.value[-1] != "\n":
605-
clause = f"{clause}\n"
592+
# open_parens = 0
593+
594+
# previous_token = None
595+
# for token in statements[0]:
596+
# if token.value == "/" and previous_token and previous_token.value == "*":
597+
# raise QueryClauseValidationException("Closing unopened multiline comment")
598+
# if token.value == "*" and previous_token and previous_token.value == "/":
599+
# raise QueryClauseValidationException("Unclosed multiline comment")
600+
# if token.value in (")", "("):
601+
# open_parens += 1 if token.value == "(" else -1
602+
# if open_parens < 0:
603+
# raise QueryClauseValidationException(
604+
# "Closing unclosed parenthesis in filter clause"
605+
# )
606+
# previous_token = token
607+
# if open_parens > 0:
608+
# raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
609+
610+
# if previous_token and previous_token.ttype in Comment:
611+
# if previous_token.value[-1] != "\n":
612+
# clause = f"{clause}\n"
606613

607614
return clause
608615

tests/unit_tests/sql_parse_tests.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from unittest.mock import Mock
2121

2222
import pytest
23+
import sqlglot
2324
import sqlparse
2425
from pytest_mock import MockerFixture
2526
from sqlalchemy import text
@@ -1168,17 +1169,17 @@ def test_messy_breakdown_statements() -> None:
11681169
]
11691170

11701171

1171-
def test_sqlparse_formatting():
1172+
def test_sqlglot_formatting():
11721173
"""
11731174
Test that ``from_unixtime`` is formatted correctly.
11741175
"""
1175-
assert sqlparse.format(
1176+
assert sqlglot.transpile(
11761177
"SELECT extract(HOUR from from_unixtime(hour_ts) "
11771178
"AT TIME ZONE 'America/Los_Angeles') from table",
1178-
reindent=True,
1179-
) == (
1180-
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
1181-
"AT TIME ZONE 'America/Los_Angeles')\nfrom table"
1179+
pretty=True,
1180+
)[0] == (
1181+
"SELECT\n EXTRACT(HOUR FROM FROM_UNIXTIME(hour_ts) AT TIME ZONE 'America/Los_Angeles')"
1182+
"\nFROM table"
11821183
)
11831184

11841185

0 commit comments

Comments
 (0)