mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(mssql): apply limit and set alias for functions (#9644)
This commit is contained in:
committed by
GitHub
parent
5e4c291913
commit
516bdf6db1
@@ -14,13 +14,20 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database # pylint: disable=unused-import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MssqlEngineSpec(BaseEngineSpec):
|
||||
@@ -76,3 +83,8 @@ class MssqlEngineSpec(BaseEngineSpec):
|
||||
if regex.match(type_):
|
||||
return sqla_type
|
||||
return None
|
||||
|
||||
@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
    """
    Alias the functions in ``sql`` and then apply the row limit.

    The statement is first rewritten with ``ParsedQuery.set_alias`` and the
    result is handed to the parent class implementation, which performs the
    actual limiting.

    :param sql: SQL query to limit
    :param limit: Maximum number of rows to return
    :param database: Database the query will run against
    :return: SQL query with aliased functions and the limit applied
    """
    aliased_sql = ParsedQuery(sql).set_alias()
    return super().apply_limit_to_sql(aliased_sql, limit, database)
|
||||
|
||||
@@ -18,7 +18,14 @@ import logging
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
remove_quotes,
|
||||
Token,
|
||||
TokenList,
|
||||
)
|
||||
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
|
||||
from sqlparse.utils import imt
|
||||
|
||||
@@ -247,3 +254,39 @@ class ParsedQuery:
|
||||
for i in statement.tokens:
|
||||
str_res += str(i.value)
|
||||
return str_res
|
||||
|
||||
def set_alias(self) -> str:
    """
    Return a new query string in which every function has an alias.

    Un-aliased functions get a generated alias of the form
    ``<function name>_<n>``, where ``n`` is a counter shared by the whole
    statement so that generated aliases never repeat. This is particularly
    necessary for MSSQL engines.

    Only the first parsed statement is processed. The parsed tokens are not
    modified, so calling this method repeatedly yields the same result.

    :return: String with new aliased SQL query ("" if nothing was parsed)
    """
    # Nothing parsed (e.g. empty input): avoid indexing into an empty result
    if not self._parsed:
        return ""

    pieces = []
    alias_counter = 1
    # NOTE(review): every top-level token is visited, so functions that appear
    # after the select list (e.g. in GROUP BY / ORDER BY) are aliased as well;
    # confirm whether aliasing should be restricted to the projection.
    for token in self._parsed[0].tokens:
        # Identifier list (list of columns)
        if isinstance(token, IdentifierList) and token.ttype is None:
            columns = []
            for identifier in token.get_identifiers():
                column_sql = str(identifier.value)
                # Functions are anonymous on MSSQL
                if isinstance(identifier, Function) and not identifier.has_alias():
                    column_sql = (
                        f"{column_sql} AS"
                        f" {identifier.get_real_name()}_{alias_counter}"
                    )
                    alias_counter += 1
                columns.append(column_sql)
            # Identifiers are re-joined with ", " regardless of original spacing
            pieces.append(", ".join(columns))
        # Just a lonely function without an alias
        elif (
            isinstance(token, Function)
            and token.ttype is None
            and not token.has_alias()
        ):
            pieces.append(
                f"{token.value} AS {token.get_real_name()}_{alias_counter}"
            )
            # Keep the counter moving so a later function cannot reuse the suffix
            alias_counter += 1
        # Nothing to change, assemble what we have
        else:
            pieces.append(str(token.value))
    return "".join(pieces)
|
||||
|
||||
@@ -15,15 +15,18 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest.mock as mock
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import column, table
|
||||
from sqlalchemy.dialects import mssql
|
||||
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.sql import select, Select
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
|
||||
|
||||
|
||||
@@ -94,6 +97,65 @@ class MssqlEngineSpecTest(DbEngineSpecTestCase):
|
||||
for actual, expected in test_cases:
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_apply_limit(self):
    """
    Check that MssqlEngineSpec.apply_limit_to_sql aliases functions and
    produces the expected limited statement for several queries.
    """

    def compile_sqla_query(qry: Select, schema: Optional[str] = None) -> str:
        # Replacement for database.compile_sqla_query that renders the query
        # with the MSSQL dialect and inlined literals; ``schema`` is accepted
        # but not used here.
        return str(
            qry.compile(
                dialect=mssql.dialect(), compile_kwargs={"literal_binds": True}
            )
        )

    # (input SQL, expected limited SQL)
    test_cases = (
        (
            "SELECT COUNT(*) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1 FROM FOO_TABLE) AS inner_qry",
        ),
        (
            "SELECT COUNT(*), SUM(id) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, SUM(id) AS SUM_2 FROM FOO_TABLE) "
            "AS inner_qry",
        ),
        (
            "SELECT COUNT(*), FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, "
            "FOO_COL1 FROM FOO_TABLE GROUP BY FOO_COL1)"
            " AS inner_qry",
        ),
        (
            "SELECT COUNT(*), COUNT(*) FROM FOO_TABLE",
            "SELECT TOP 1000 * \n"
            "FROM (SELECT COUNT(*) AS COUNT_1, COUNT(*) AS COUNT_2 FROM FOO_TABLE)"
            " AS inner_qry",
        ),
    )

    database = Database(
        database_name="mssql_test",
        sqlalchemy_uri="mssql+pymssql://sa:Password_123@localhost:1433/msdb",
    )
    db.session.add(database)
    db.session.commit()

    # Always remove the fixture row, even when an assertion fails, so that a
    # failing run does not leave "mssql_test" behind for later tests.
    try:
        with mock.patch.object(
            database, "compile_sqla_query", new=compile_sqla_query
        ):
            for test_sql, expected_sql in test_cases:
                limited_sql = MssqlEngineSpec.apply_limit_to_sql(
                    test_sql, 1000, database
                )
                self.assertEqual(expected_sql, limited_sql)
    finally:
        db.session.delete(database)
        db.session.commit()
|
||||
|
||||
@mock.patch.object(
|
||||
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user