diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index be2b3ee7d1a..7ea5bc6a2a1 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -39,7 +39,6 @@ from uuid import uuid4 import pandas as pd import requests -import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from deprecation import deprecated @@ -57,14 +56,19 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import literal_column, quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextClause from sqlalchemy.types import TypeEngine -from sqlparse.tokens import CTE from superset import db, sql_parse from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError -from superset.sql.parse import BaseSQLStatement, LimitMethod, SQLScript, Table +from superset.sql.parse import ( + BaseSQLStatement, + LimitMethod, + SQLScript, + SQLStatement, + Table, +) from superset.superset_typing import ( OAuth2ClientConfig, OAuth2State, @@ -1124,18 +1128,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ if not cls.allows_cte_in_subquery: - stmt = sqlparse.parse(sql)[0] - - # The first meaningful token for CTE will be with WITH - idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) - if not (token and token.ttype == CTE): - return None - idx, token = stmt.token_next(idx) - idx = stmt.token_index(token) + 1 - - # extract rest of the SQLs after CTE - remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() - return f"WITH {token.value},\n{cls.cte_alias} AS (\n{remainder}\n)" + statement = SQLStatement(sql, engine=cls.engine) + if statement.has_cte(): + return 
statement.as_cte(cls.cte_alias).format() return None diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 99b8ab60d82..0cada5f0e5c 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -276,6 +276,23 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def has_cte(self) -> bool: + """ + Check if the statement has a CTE. + + :return: True if the statement has a CTE at the top level. + """ + raise NotImplementedError() + + def as_cte(self, alias: str = "__cte") -> SQLStatement: + """ + Rewrite the statement as a CTE. + + :param alias: The alias to use for the CTE. + :return: A new SQLStatement with the CTE. + """ + raise NotImplementedError() + def __str__(self) -> str: return self.format() @@ -526,6 +543,36 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): else: # method == LimitMethod.FETCH_MANY pass + def has_cte(self) -> bool: + """ + Check if the statement has a CTE. + + :return: True if the statement has a CTE at the top level. + """ + return "with" in self._parsed.args + + def as_cte(self, alias: str = "__cte") -> SQLStatement: + """ + Rewrite the statement as a CTE. + + This is needed by MS SQL when the query includes CTEs. In that case the CTEs + need to be moved to the top of the query when we wrap it as a subquery when + building charts. + + :param alias: The alias to use for the CTE. + :return: A new SQLStatement with the CTE. 
+ """ + existing_ctes = self._parsed.args["with"].expressions if self.has_cte() else [] + self._parsed.args["with"] = None + new_cte = exp.CTE( + this=self._parsed.copy(), + alias=exp.TableAlias(this=exp.Identifier(this=alias)), + ) + return SQLStatement( + ast=exp.With(expressions=[*existing_ctes, new_cte], this=None), + engine=self.engine, + ) + class KQLSplitState(enum.Enum): """ diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 62832c4f3cf..e0ce5e1180c 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -217,17 +217,21 @@ select 'EUR' as cur select * from currency union all select * from currency_2 """ ), - dedent( - """WITH currency as ( -select 'INR' as cur -), -currency_2 as ( -select 'EUR' as cur -), -__cte AS ( -select * from currency union all select * from currency_2 -)""" - ), + """WITH currency AS ( + SELECT + 'INR' AS cur +), currency_2 AS ( + SELECT + 'EUR' AS cur +), __cte AS ( + SELECT + * + FROM currency + UNION ALL + SELECT + * + FROM currency_2 +)""", ), ( "SELECT 1 as cnt", diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 3d0746b5c79..d750870c987 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -1519,3 +1519,109 @@ def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None: statement = KustoKQLStatement(kql, "kustokql") statement.set_limit_value(limit) assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT 1", "postgresql", False), + ("SELECT 1 AS cnt", "postgresql", False), + ( + """ +SELECT 'INR' AS cur +UNION +SELECT 'USD' AS cur +UNION +SELECT 'EUR' AS cur + """, + "postgresql", + False, + ), + ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True), + ( + """ +WITH + x AS (SELECT a FROM t1), + y AS (SELECT a AS b FROM t2), + z AS (SELECT b AS c 
FROM t3) +SELECT c FROM z + """, + "postgresql", + True, + ), + ( + """ +WITH + x AS (SELECT a FROM t1), + y AS (SELECT a AS b FROM x), + z AS (SELECT b AS c FROM y) +SELECT c FROM z + """, + "postgresql", + True, + ), + ( + """ +WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear) +AS ( + SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear + FROM SalesOrderHeader + WHERE SalesPersonID IS NOT NULL +) +SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear +FROM CTE__test +GROUP BY SalesYear, SalesPersonID +ORDER BY SalesPersonID, SalesYear; + """, + "postgresql", + True, + ), + ], +) +def test_has_cte(sql: str, engine: str, expected: bool) -> None: + """ + Test that the parser detects CTEs correctly. + """ + assert SQLStatement(sql, engine).has_cte() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ( + "SELECT 1", + "postgresql", + "WITH __cte AS (\n SELECT\n 1\n)", + ), + ( + """ +WITH currency AS (SELECT 'INR' AS cur), + currency_2 AS (SELECT 'USD' AS cur) +SELECT * FROM currency +UNION ALL +SELECT * FROM currency_2 + """, + "postgresql", + """WITH currency AS ( + SELECT + 'INR' AS cur +), currency_2 AS ( + SELECT + 'USD' AS cur +), __cte AS ( + SELECT + * + FROM currency + UNION ALL + SELECT + * + FROM currency_2 +)""", + ), + ], +) +def test_as_cte(sql: str, engine: str, expected: str) -> None: + """ + Test that we can convert select to CTE. + """ + assert SQLStatement(sql, engine).as_cte().format() == expected