# --- superset/db_engine_specs/mssql.py: method of MssqlEngineSpec ---
@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
    """
    Alias unnamed function columns in ``sql``, then apply the row limit.

    The query is first rewritten with ``ParsedQuery.set_alias`` and the
    rewritten text is handed to the parent class implementation.

    :param sql: SQL query to limit
    :param limit: Maximum number of rows, passed through to the parent
    :param database: Database the query targets, passed through to the parent
    :return: Whatever the parent ``apply_limit_to_sql`` returns for the
        aliased query
    """
    # NOTE(review): presumably needed because the parent wraps the query in a
    # subquery (see the expected SQL in test_apply_limit), where MSSQL rejects
    # unnamed columns -- confirm against BaseEngineSpec.
    aliased_sql = ParsedQuery(sql).set_alias()
    return super().apply_limit_to_sql(aliased_sql, limit, database)


# --- superset/sql_parse.py: method of ParsedQuery ---
def set_alias(self) -> str:
    """
    Returns a new query string where all functions have alias.
    This is particularly necessary for MSSQL engines.

    Only functions that are direct children of a statement, or direct
    members of a top-level identifier list (the column list), are aliased.
    Each one gets ``AS <function name>_<n>`` appended, where ``n`` is a
    counter starting at 1 that grows with every alias added.

    The parsed tokens are only read, never modified, so calling this
    method repeatedly yields the same string. Every parsed statement is
    included in the result, and a query that parses to no statements
    yields an empty string.

    :return: String with new aliased SQL query
    """
    fragments: List[str] = []
    alias_number = 1
    for statement in self._parsed:
        for token in statement.tokens:
            # Identifier list (list of columns)
            if isinstance(token, IdentifierList):
                columns: List[str] = []
                for identifier in token.get_identifiers():
                    column_sql = str(identifier.value)
                    # Functions are anonymous on MSSQL
                    if (
                        isinstance(identifier, Function)
                        and not identifier.has_alias()
                    ):
                        column_sql = (
                            f"{column_sql} AS"
                            f" {identifier.get_real_name()}_{alias_number}"
                        )
                        alias_number += 1
                    columns.append(column_sql)
                # The list is re-joined with ", ", so the original
                # separators between columns are normalised.
                fragments.append(", ".join(columns))
            # Just a lonely function?
            elif isinstance(token, Function) and not token.has_alias():
                fragments.append(
                    f"{token.value} AS {token.get_real_name()}_{alias_number}"
                )
                alias_number += 1
            # Nothing to change, assemble what we have
            else:
                fragments.append(str(token.value))
    return "".join(fragments)
def test_apply_limit(self):
    """
    Check that ``MssqlEngineSpec.apply_limit_to_sql`` aliases unnamed
    function columns before wrapping the query in ``SELECT TOP ...``.

    A ``Database`` row is added to the session for the duration of the
    test and is removed again even when an assertion fails.
    """

    def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
        # Stand-in for Database.compile_sqla_query that compiles with the
        # MSSQL dialect directly; ``schema`` is accepted but not used.
        return str(
            qry.compile(
                dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
            )
        )

    # (input SQL, expected limited SQL)
    test_cases = (
        # Single stand-alone function
        (
            "SELECT COUNT(*) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry",
        ),
        # Several different functions
        (
            "SELECT COUNT(*), SUM(id) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
            "AS inner_qry",
        ),
        # Function mixed with a plain column
        (
            "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, "
            "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
            " AS inner_qry",
        ),
        # The same function twice must get distinct aliases
        (
            "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
            " AS inner_qry",
        ),
    )

    database = Database(
        database_name="mssql_test",
        sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
    )
    db.session.add(database)
    db.session.commit()

    try:
        with mock.patch.object(
            database, "compile_sqla_query", new=compile_sqla_query
        ):
            for test_sql, expected_sql in test_cases:
                limited_sql = MssqlEngineSpec.apply_limit_to_sql(
                    test_sql, 1000, database
                )
                self.assertEqual(expected_sql, limited_sql)
    finally:
        # Always remove the fixture row so a failing assertion does not
        # leave it behind for later tests.
        db.session.delete(database)
        db.session.commit()