fix(mssql): apply limit and set alias for functions (#9644)

This commit is contained in:
Daniel Vaz Gaspar
2020-04-27 09:23:08 +01:00
committed by GitHub
parent 5e4c291913
commit 516bdf6db1
3 changed files with 120 additions and 3 deletions

View File

@@ -14,13 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import ParsedQuery
if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import
logger = logging.getLogger(__name__)
class MssqlEngineSpec(BaseEngineSpec):
@@ -76,3 +83,8 @@ class MssqlEngineSpec(BaseEngineSpec):
if regex.match(type_):
return sqla_type
return None
@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
new_sql = ParsedQuery(sql).set_alias()
return super().apply_limit_to_sql(new_sql, limit, database)

View File

@@ -18,7 +18,14 @@ import logging
from typing import List, Optional, Set
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
@@ -247,3 +254,39 @@ class ParsedQuery:
for i in statement.tokens:
str_res += str(i.value)
return str_res
def set_alias(self) -> str:
"""
Returns a new query string where all functions have alias.
This is particularly necessary for MSSQL engines.
:return: String with new aliased SQL query
"""
new_sql = ""
changed_counter = 1
for token in self._parsed[0].tokens:
# Identifier list (list of columns)
if isinstance(token, IdentifierList) and token.ttype is None:
for i, identifier in enumerate(token.get_identifiers()):
# Functions are anonymous on MSSQL
if isinstance(identifier, Function) and not identifier.has_alias():
identifier.value = (
f"{identifier.value} AS"
f" {identifier.get_real_name()}_{changed_counter}"
)
changed_counter += 1
new_sql += str(identifier.value)
# If not last identifier
if i != len(list(token.get_identifiers())) - 1:
new_sql += ", "
# Just a lonely function?
elif isinstance(token, Function) and token.ttype is None:
if not token.has_alias():
token.value = (
f"{token.value} AS {token.get_real_name()}_{changed_counter}"
)
new_sql += str(token.value)
# Nothing to change, assemble what we have
else:
new_sql += str(token.value)
return new_sql

View File

@@ -15,15 +15,18 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from typing import Optional
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
from sqlalchemy.sql import select, Select
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.extensions import db
from superset.models.core import Database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
@@ -94,6 +97,65 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
for actual, expected in test_cases:
self.assertEqual(actual, expected)
def test_apply_limit(self):
def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
return str(
qry.compile(
dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
)
)
database = Database(
database_name="mssql_test",
sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
)
db.session.add(database)
db.session.commit()
with mock.patch.object(database, "compile_sqla_query", new=compile_sqla_query):
test_sql = "SELECT COUNT(*) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), SUM(id) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
"AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, "
"FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
test_sql = "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE"
limited_sql = MssqlEngineSpec.apply_limit_to_sql(test_sql, 1000, database)
expected_sql = (
"SELECT TOP 1000 * \n"
"FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
" AS inner_qry"
)
self.assertEqual(expected_sql, limited_sql)
db.session.delete(database)
db.session.commit()
@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)