diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 004b2dd53dc..70526ea12ee 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -55,7 +55,7 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import literal_column, quoted_name, text -from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause +from sqlalchemy.sql.expression import ColumnClause, Select, TextClause from sqlalchemy.types import TypeEngine from sqlparse.tokens import CTE @@ -64,7 +64,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError -from superset.sql.parse import BaseSQLStatement, SQLScript, Table +from superset.sql.parse import BaseSQLStatement, LimitMethod, SQLScript, Table from superset.superset_typing import ( OAuth2ClientConfig, OAuth2State, @@ -165,14 +165,6 @@ def compile_timegrain_expression( return element.name.replace("{col}", compiler.process(element.col, **kwargs)) -class LimitMethod: # pylint: disable=too-few-public-methods - """Enum the ways that limits can be applied""" - - FETCH_MANY = "fetch_many" - WRAP_SQL = "wrap_sql" - FORCE_LIMIT = "force_limit" - - class MetricType(TypedDict, total=False): """ Type for metrics return by `get_metrics`. @@ -376,16 +368,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods allows_cte_in_subquery = True # Define alias for CTE cte_alias = "__cte" - # Whether allow LIMIT clause in the SQL - # If True, then the database engine is allowed for LIMIT clause - # If False, then the database engine is allowed for TOP clause - allow_limit_clause = True # This set will give keywords for select statements # to consider for the engines with TOP SQL parsing select_keywords: set[str] = {"SELECT"} - # This set will give the keywords for data limit statements - # to consider for the engines with TOP SQL parsing - top_keywords: set[str] = {"TOP"} # A set of disallowed connection query parameters by driver name disallow_uri_query_params: dict[str, set[str]] = {} # A Dict of query parameters that will always be used on every connection @@ -1118,100 +1103,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return {} - @classmethod - def apply_limit_to_sql( - cls, sql: str, limit: int, database: Database, force: bool = False - ) -> str: - """ - Alters the SQL statement to apply a LIMIT clause - - :param sql: SQL query - :param limit: Maximum number of rows to be returned by the query - :param database: Database instance - :return: SQL query with limit clause - """ - if cls.limit_method == LimitMethod.WRAP_SQL: - sql = sql.strip("\t\n ;") - qry = ( - select("*") - .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")) - .limit(limit) - ) - return database.compile_sqla_query(qry) - - if cls.limit_method == LimitMethod.FORCE_LIMIT: - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - sql = parsed_query.set_or_update_query_limit(limit, force=force) - - return sql - - @classmethod - def apply_top_to_sql(cls, sql: str, limit: int) -> str: # noqa: C901 - """ - Alters the SQL statement to apply a TOP clause - :param limit: Maximum number of rows to be returned by the query - :param sql: SQL query - :return: SQL query with top clause - """ - - cte = None - sql_remainder = None - sql = sql.strip(" \t\n;") - query_limit: int | None = sql_parse.extract_top_from_query( - sql, cls.top_keywords - ) - if not limit: - final_limit = query_limit - elif int(query_limit or 0) < limit and query_limit is not None: - final_limit = query_limit - else: - final_limit = limit - if not cls.allows_cte_in_subquery: - cte, sql_remainder = sql_parse.get_cte_remainder_query(sql) - if cte: - str_statement = str(sql_remainder) - cte = cte + "\n" - else: - cte = "" - str_statement = str(sql) - str_statement = str_statement.replace("\n", " ").replace("\r", "") - - tokens = str_statement.rstrip().split(" ") - tokens = [token for token in tokens if token] - if cls.top_not_in_sql(str_statement): - selects = [ - i - for i, word in enumerate(tokens) - if word.upper() in cls.select_keywords - ] - first_select = selects[0] - if tokens[first_select + 1].upper() == "DISTINCT": - first_select += 1 - - tokens.insert(first_select + 1, "TOP") - tokens.insert(first_select + 2, str(final_limit)) - - next_is_limit_token = False - new_tokens = [] - - for token in tokens: - if token in cls.top_keywords: - next_is_limit_token = True - elif next_is_limit_token: - if token.isdigit(): - token = str(final_limit) - next_is_limit_token = False - new_tokens.append(token) - sql = " ".join(new_tokens) - return cte + sql - - @classmethod - def top_not_in_sql(cls, sql: str) -> bool: - for top_word in cls.top_keywords: - if top_word.upper() in sql.upper(): - return False - return True - @classmethod def get_limit_from_sql(cls, sql: str) -> int | None: """ @@ -1223,18 +1114,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods script = SQLScript(sql, engine=cls.engine) return script.statements[-1].get_limit_value() - @classmethod - def set_or_update_query_limit(cls, sql: str, limit: int) -> str: - """ - Create a query based on original query but with new limit clause - - :param sql: SQL query - :param limit: New limit to insert/replace into query - :return: Query with new limit - """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - return parsed_query.set_or_update_query_limit(limit) - @classmethod def get_cte_query(cls, sql: str) -> str | None: """ @@ -1685,8 +1564,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods full_table_name = cls.quote_table(table, engine.dialect) qry = select(fields).select_from(text(full_table_name)) - if limit and cls.allow_limit_clause: - qry = qry.limit(limit) + qry = qry.limit(limit) if latest_partition: partition_query = cls.where_latest_partition( database, diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index 6781701ac79..8dd7b00a6b9 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -20,9 +20,9 @@ from typing import Optional, Union from sqlalchemy.engine.reflection import Inspector from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import Database -from superset.sql_parse import Table +from superset.sql.parse import LimitMethod, Table logger = logging.getLogger(__name__) diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py index 15c4bef7bf4..d8222d81d4f 100644 --- a/superset/db_engine_specs/firebird.py +++ b/superset/db_engine_specs/firebird.py @@ -20,7 +20,8 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec +from superset.sql.parse import LimitMethod class FirebirdEngineSpec(BaseEngineSpec): diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py index 13b5674c87a..3ae70349eca 100644 --- a/superset/db_engine_specs/hana.py +++ b/superset/db_engine_specs/hana.py @@ -20,8 +20,8 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import TimeGrain -from superset.db_engine_specs.base import LimitMethod from superset.db_engine_specs.postgres import PostgresBaseEngineSpec +from superset.sql.parse import LimitMethod class HanaEngineSpec(PostgresBaseEngineSpec): diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index a48c83182a5..2081f6c89ce 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -22,12 +22,13 @@ from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import ( SupersetDBAPIDatabaseError, SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) +from superset.sql.parse import LimitMethod from superset.utils.core import GenericDataType @@ -105,7 +106,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method - limit_method = LimitMethod.WRAP_SQL engine = "kustokql" engine_name = "KustoKQL" time_groupby_inline = True diff --git a/superset/db_engine_specs/lib.py b/superset/db_engine_specs/lib.py index 106b9c75503..e66bcccc353 100644 --- a/superset/db_engine_specs/lib.py +++ b/superset/db_engine_specs/lib.py @@ -119,7 +119,7 @@ def diagnose(spec: type[BaseEngineSpec]) -> dict[str, Any]: output.update( { "module": spec.__module__, - "limit_method": spec.limit_method.upper(), + "limit_method": spec.limit_method.value, "joins": spec.allows_joins, "subqueries": spec.allows_subqueries, "alias_in_select": spec.allows_alias_in_select, @@ -129,7 +129,6 @@ def diagnose(spec: type[BaseEngineSpec]) -> dict[str, Any]: "order_by_not_in_select": spec.allows_hidden_orderby_agg, "expressions_in_orderby": spec.allows_hidden_cc_in_orderby, "cte_in_subquery": spec.allows_cte_in_subquery, - "limit_clause": spec.allow_limit_clause, "max_column_name": spec.max_column_name_length, "sql_comments": spec.allows_sql_comments, "escaped_colons": spec.allows_escaped_colons, @@ -223,7 +222,7 @@ def generate_table() -> list[list[Any]]: rows = [] # pylint: disable=redefined-outer-name rows.append(["Feature"] + list(info)) # header row - rows.append(["Module"] + list(db_info["module"] for db_info in info.values())) # noqa: C400 + rows.append(["Module"] + [db_info["module"] for db_info in info.values()]) # descriptive keys = [ @@ -244,14 +243,14 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [DATABASE_DETAILS[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 + [DATABASE_DETAILS[key]] + [db_info[key] for db_info in info.values()] ) # basic for time_grain in TimeGrain: rows.append( [f"Has time grain {time_grain.name}"] - + list(db_info["time_grains"][time_grain.name] for db_info in info.values()) # noqa: C400 + + [db_info["time_grains"][time_grain.name] for db_info in info.values()] ) keys = [ "masked_encrypted_extra", @@ -259,9 +258,7 @@ def generate_table() -> list[list[Any]]: "function_names", ] for key in keys: - rows.append( - [BASIC_FEATURES[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 - ) + rows.append([BASIC_FEATURES[key]] + [db_info[key] for db_info in info.values()]) # nice to have keys = [ @@ -280,8 +277,7 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [NICE_TO_HAVE_FEATURES[key]] - + list(db_info[key] for db_info in info.values()) # noqa: C400 + [NICE_TO_HAVE_FEATURES[key]] + [db_info[key] for db_info in info.values()] ) # advanced @@ -292,10 +288,10 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [ADVANCED_FEATURES[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 + [ADVANCED_FEATURES[key]] + [db_info[key] for db_info in info.values()] ) - rows.append(["Score"] + list(db_info["score"] for db_info in info.values())) # noqa: C400 + rows.append(["Score"] + [db_info["score"] for db_info in info.values()]) return rows diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index c1f7e295dee..6e238e9ddfb 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -27,7 +27,7 @@ from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType from superset.models.sql_types.mssql_sql_types import GUID from superset.utils.core import GenericDataType @@ -52,10 +52,8 @@ CONNECTION_HOST_DOWN_REGEX = re.compile( class MssqlEngineSpec(BaseEngineSpec): engine = "mssql" engine_name = "Microsoft SQL Server" - limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 allows_cte_in_subquery = False - allow_limit_clause = False supports_multivalues_insert = True _time_grain_expressions = { diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index a7b97ed6996..75889c706c2 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -225,7 +225,6 @@ def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]: class OcientEngineSpec(BaseEngineSpec): engine = "ocient" engine_name = "Ocient" - # limit_method = LimitMethod.WRAP_SQL force_column_alias_quotes = True max_column_name_length = 30 diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py index f03cea49120..1df5736b824 100644 --- a/superset/db_engine_specs/oracle.py +++ b/superset/db_engine_specs/oracle.py @@ -20,13 +20,12 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec class OracleEngineSpec(BaseEngineSpec): engine = "oracle" engine_name = "Oracle" - limit_method = LimitMethod.WRAP_SQL force_column_alias_quotes = True max_column_name_length = 30 diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py index 887add24e90..08c9e9b7c99 100644 --- a/superset/db_engine_specs/teradata.py +++ b/superset/db_engine_specs/teradata.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec class TeradataEngineSpec(BaseEngineSpec): @@ -23,11 +23,8 @@ class TeradataEngineSpec(BaseEngineSpec): engine = "teradatasql" engine_name = "Teradata" - limit_method = LimitMethod.WRAP_SQL max_column_name_length = 30 # since 14.10 this is 128 - allow_limit_clause = False select_keywords = {"SELECT", "SEL"} - top_keywords = {"TOP", "SAMPLE"} _time_grain_expressions = { None: "{col}", diff --git a/superset/models/core.py b/superset/models/core.py index e1322be6f2a..bc545c95dc7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -761,11 +761,19 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ) def apply_limit_to_sql( - self, sql: str, limit: int = 1000, force: bool = False + self, + sql: str, + limit: int = 1000, + force: bool = False, ) -> str: - if self.db_engine_spec.allow_limit_clause: - return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force) - return self.db_engine_spec.apply_top_to_sql(sql, limit) + script = SQLScript(sql, self.db_engine_spec.engine) + statement = script.statements[-1] + current_limit = statement.get_limit_value() or float("inf") + + if limit < current_limit or force: + statement.set_limit_value(limit, self.db_engine_spec.limit_method) + + return script.format() def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 84017736cfe..d580a7a80e0 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -27,12 +27,9 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar import sqlglot -import sqlparse -from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError -from sqlglot.expressions import Func, Limit from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope @@ -99,6 +96,18 @@ SQLGLOT_DIALECTS = { } +class LimitMethod(enum.Enum): + """ + Limit methods. + + This is used to determine how to add a limit to a SQL statement. + """ + + FORCE_LIMIT = enum.auto() + WRAP_SQL = enum.auto() + FETCH_MANY = enum.auto() + + @dataclass(eq=True, frozen=True) class Table: """ @@ -252,6 +261,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Add a limit to the statement. + """ + raise NotImplementedError() + def __str__(self) -> str: return self.format() @@ -412,34 +431,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ Pretty-format the SQL statement. """ - if self._dialect: - try: - write = Dialect.get_or_raise(self._dialect) - return write.generate( - self._parsed, - copy=False, - comments=comments, - pretty=True, - ) - except ValueError: - pass - - return self._fallback_formatting() - - @deprecated(deprecated_in="4.0") - def _fallback_formatting(self) -> str: - """ - Format SQL without a specific dialect. - - Reformatting SQL using the generic sqlglot dialect is known to break queries. - For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which - breaks the query for Firebolt. To avoid this, we use sqlparse for formatting - when the dialect is not known. - - In 5.0 we should remove `sqlparse`, and the method should return the query - unmodified. - """ - return sqlparse.format(self._sql, reindent=True, keyword_case="upper") + return Dialect.get_or_raise(self._dialect).generate( + self._parsed, + copy=True, + comments=comments, + pretty=True, + ) def get_settings(self) -> dict[str, str | bool]: """ @@ -482,7 +479,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): if function.sql_name() != "ANONYMOUS" else function.name.upper() ) - for function in self._parsed.find_all(Func) + for function in self._parsed.find_all(exp.Func) } return any(function.upper() in present for function in functions) @@ -490,20 +487,38 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ Parse a SQL query and return the `LIMIT` or `TOP` value, if present. """ - limit_node = ( - self._parsed - if isinstance(self._parsed, Limit) - else self._parsed.args.get("limit") - ) - if not isinstance(limit_node, exp.Limit): - return None - - literal = limit_node.args.get("expression") or getattr(limit_node, "this", None) - if isinstance(literal, exp.Literal) and literal.is_int: - return int(literal.name) + if limit_node := self._parsed.args.get("limit"): + literal = limit_node.args.get("expression") or getattr( + limit_node, "this", None + ) + if isinstance(literal, exp.Literal) and literal.is_int: + return int(literal.name) return None + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Modify the `LIMIT` or `TOP` value of the SQL statement inplace. + """ + if method == LimitMethod.FORCE_LIMIT: + self._parsed.args["limit"] = exp.Limit( + expression=exp.Literal(this=str(limit), is_string=False) + ) + elif method == LimitMethod.WRAP_SQL: + self._parsed = exp.Select( + expressions=[exp.Star()], + limit=exp.Limit( + expression=exp.Literal(this=str(limit), is_string=False) + ), + **{"from": exp.From(this=exp.Subquery(this=self._parsed.copy()))}, + ) + else: # method == LimitMethod.FETCH_MANY + pass + class KQLSplitState(enum.Enum): """ @@ -561,7 +576,7 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]: state = KQLSplitState.OUTSIDE_STRING tokens: list[tuple[KQLTokenType, str]] = [] buffer = "" - script = kql if kql.endswith(";") else kql + ";" + script = kql for i, ch in enumerate(script): if state == KQLSplitState.OUTSIDE_STRING: @@ -630,6 +645,9 @@ def split_kql(kql: str) -> list[str]: else: current.append((ttype, val)) + if current: + stmts_tokens.append(current) + return ["".join(val for _, val in stmt) for stmt in stmts_tokens] @@ -767,6 +785,40 @@ class KustoKQLStatement(BaseSQLStatement[str]): return None + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Add a limit to the statement. + """ + if method != LimitMethod.FORCE_LIMIT: + raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.") + + tokens = tokenize_kql(self._sql) + found_limit_token = False + for idx, (ttype, val) in enumerate(tokens): + if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}: + found_limit_token = True + + if found_limit_token and ttype == KQLTokenType.NUMBER: + tokens[idx] = (KQLTokenType.NUMBER, str(limit)) + break + else: + tokens.extend( + [ + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.WORD, "|"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.WORD, "take"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.NUMBER, str(limit)), + ] + ) + + self._parsed = self._sql = "".join(val for _, val in tokens) + class SQLScript: """ diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8fae4507efa..e141a4f2900 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -469,39 +469,6 @@ class ParsedQuery: exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" return exec_sql - def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: - """Returns the query with the specified limit. - - Does not change the underlying query if user did not apply the limit, - otherwise replaces the limit with the lower value between existing limit - in the query and new_limit. - - :param new_limit: Limit to be incorporated into returned query - :return: The original query with new limit - """ - if not self._limit: - return f"{self.stripped()}\nLIMIT {new_limit}" - limit_pos = None - statement = self._parsed[0] - # Add all items to before_str until there is a limit - for pos, item in enumerate(statement.tokens): - if item.ttype in Keyword and item.value.lower() == "limit": - limit_pos = pos - break - _, limit = statement.token_next(idx=limit_pos) - # Override the limit only when it exceeds the configured value. - if limit.ttype == sqlparse.tokens.Literal.Number.Integer and ( - force or new_limit < int(limit.value) - ): - limit.value = new_limit - elif limit.is_group: - limit.value = f"{next(limit.get_identifiers())}, {new_limit}" - - str_res = "" - for i in statement.tokens: - str_res += str(i.value) - return str_res - def sanitize_clause(clause: str) -> str: # clause = sqlparse.format(clause, strip_comments=True) diff --git a/tests/integration_tests/db_engine_specs/ascend_tests.py b/tests/integration_tests/db_engine_specs/ascend_tests.py index cd1fa372858..045cac7d76a 100644 --- a/tests/integration_tests/db_engine_specs/ascend_tests.py +++ b/tests/integration_tests/db_engine_specs/ascend_tests.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. from superset.db_engine_specs.ascend import AscendEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestAscendDbEngineSpec(TestDbEngineSpec): +class TestAscendDbEngineSpec(SupersetTestCase): def test_convert_dttm(self): dttm = self.get_dttm() diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 7dd9c5cb95e..46a5951f10f 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -25,14 +25,13 @@ from superset.db_engine_specs.base import ( BaseEngineSpec, BasicParametersMixin, builtin_time_grains, - LimitMethod, ) from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app from ..fixtures.birth_names_dashboard import ( @@ -46,7 +45,7 @@ from ..fixtures.energy_dashboard import ( from ..fixtures.pyodbcRow import Row -class TestDbEngineSpecs(TestDbEngineSpec): +class SupersetTestCases(SupersetTestCase): def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec): q0 = "select * from table" q1 = "select * from mytable limit 10" @@ -74,124 +73,9 @@ class TestDbEngineSpecs(TestDbEngineSpec): assert engine_spec_class.get_limit_from_sql(q10) is None assert engine_spec_class.get_limit_from_sql(q11) is None - def test_wrapped_semi_tabs(self): - self.sql_limit_regex( - "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000" - ) - - def test_simple_limit_query(self): - self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000") - - def test_modify_limit_query(self): - self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000") - - def test_limit_query_with_limit_subquery(self): # pylint: disable=invalid-name - self.sql_limit_regex( - "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999", - "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000", - ) - - def test_limit_query_without_force(self): - self.sql_limit_regex( - "SELECT * FROM a LIMIT 10", - "SELECT * FROM a LIMIT 10", - limit=11, - ) - - def test_limit_query_with_force(self): - self.sql_limit_regex( - "SELECT * FROM a LIMIT 10", - "SELECT * FROM a LIMIT 11", - limit=11, - force=True, - ) - - def test_limit_with_expr(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000""", - ) - - def test_limit_expr_and_semicolon(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990 ;""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000""", - ) - def test_get_datatype(self): assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR") - def test_limit_with_implicit_offset(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990, 999999""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990, 1000""", - ) - - def test_limit_with_explicit_offset(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990 - OFFSET 999999""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000 - OFFSET 999999""", - ) - - def test_limit_with_non_token_limit(self): - self.sql_limit_regex( - """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000""" - ) - - def test_limit_with_fetch_many(self): - class DummyEngineSpec(BaseEngineSpec): - limit_method = LimitMethod.FETCH_MANY - - self.sql_limit_regex( - "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec - ) - def test_engine_time_grain_validity(self): time_grains = set(builtin_time_grains.keys()) # loop over all subclasses of BaseEngineSpec diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py deleted file mode 100644 index c836e71b689..00000000000 --- a/tests/integration_tests/db_engine_specs/base_tests.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# isort:skip_file - -from tests.integration_tests.test_app import app # noqa: F401 -from tests.integration_tests.base_tests import SupersetTestCase -from superset.db_engine_specs.base import BaseEngineSpec -from superset.models.core import Database - - -class TestDbEngineSpec(SupersetTestCase): - def sql_limit_regex( - self, - sql, - expected_sql, - engine_spec_class=BaseEngineSpec, - limit=1000, - force=False, - ): - main = Database(database_name="test_database", sqlalchemy_uri="sqlite://") - limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force) - assert expected_sql == limited diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 636fc3523ae..00b8414127a 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -26,7 +26,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 load_birth_names_data, # noqa: F401 @@ -42,7 +42,7 @@ def mock_engine_with_credentials(*args, **kwargs): yield engine_mock -class TestBigQueryDbEngineSpec(TestDbEngineSpec): +class TestBigQueryDbEngineSpec(SupersetTestCase): def test_bigquery_sqla_column_label(self): """ DB Eng Specs (bigquery): Test column label diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py index bf4d7e8b9f9..ec6ed2964ef 100644 --- a/tests/integration_tests/db_engine_specs/databricks_tests.py +++ b/tests/integration_tests/db_engine_specs/databricks_tests.py @@ -18,12 +18,12 @@ from unittest import mock from superset.db_engine_specs import get_engine_spec from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra -class TestDatabricksDbEngineSpec(TestDbEngineSpec): +class TestDatabricksDbEngineSpec(SupersetTestCase): def test_get_engine_spec(self): """ DB Eng Specs (databricks): Test "databricks" in engine spec diff --git a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py index 8027c031a5d..2ac4f6aa2a8 100644 --- a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py +++ b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py @@ -19,10 +19,10 @@ from sqlalchemy import column from superset.constants import TimeGrain from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestElasticsearchDbEngineSpec(TestDbEngineSpec): +class TestElasticsearchDbEngineSpec(SupersetTestCase): @parameterized.expand( [ [TimeGrain.SECOND, "DATE_TRUNC('second', ts)"], diff --git a/tests/integration_tests/db_engine_specs/gsheets_tests.py b/tests/integration_tests/db_engine_specs/gsheets_tests.py index 212af15c333..6368d730d47 100644 --- a/tests/integration_tests/db_engine_specs/gsheets_tests.py +++ b/tests/integration_tests/db_engine_specs/gsheets_tests.py @@ -16,10 +16,10 @@ # under the License. from superset.db_engine_specs.gsheets import GSheetsEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestGsheetsDbEngineSpec(TestDbEngineSpec): +class TestGsheetsDbEngineSpec(SupersetTestCase): def test_extract_errors(self): """ Test that custom error messages are extracted correctly. diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index c6bbfb683ff..05e31bbee55 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -17,7 +17,6 @@ # isort:skip_file from unittest import mock import unittest -from .base_tests import SupersetTestCase import pytest import pandas as pd @@ -26,6 +25,7 @@ from sqlalchemy.sql import select from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException from superset.sql_parse import Table +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index 23af61f17d9..2698721c651 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -21,12 +21,12 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): +class TestMySQLEngineSpecsDbEngineSpec(SupersetTestCase): @unittest.skipUnless( - TestDbEngineSpec.is_module_installed("MySQLdb"), "mysqlclient not installed" + SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" ) def test_get_datatype_mysql(self): """Tests related to datatype mapping for MySQL""" diff --git a/tests/integration_tests/db_engine_specs/pinot_tests.py b/tests/integration_tests/db_engine_specs/pinot_tests.py index 66d4865fb81..f6872a2e2bd 100755 --- a/tests/integration_tests/db_engine_specs/pinot_tests.py +++ b/tests/integration_tests/db_engine_specs/pinot_tests.py @@ -17,10 +17,10 @@ from sqlalchemy import column from superset.db_engine_specs.pinot import PinotEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestPinotDbEngineSpec(TestDbEngineSpec): +class TestPinotDbEngineSpec(SupersetTestCase): """Tests pertaining to our Pinot database support""" def test_pinot_time_expression_sec_one_1d_grain(self): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e45e0189f40..236d293df4b 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -27,12 +27,12 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query from superset.utils.core import backend from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra -class TestPostgresDbEngineSpec(TestDbEngineSpec): +class TestPostgresDbEngineSpec(SupersetTestCase): def test_get_table_names(self): """ DB Eng Specs (postgres): Test get table names diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index c57bec88008..2a27d12df9a 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -27,11 +27,11 @@ from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestPrestoDbEngineSpec(TestDbEngineSpec): - @skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed") +class TestPrestoDbEngineSpec(SupersetTestCase): + @skipUnless(SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed") def test_get_datatype_presto(self): assert "STRING" == PrestoEngineSpec.get_datatype("string") diff --git a/tests/integration_tests/db_engine_specs/redshift_tests.py b/tests/integration_tests/db_engine_specs/redshift_tests.py index 2d46c73fca7..38d3bc091c5 100644 --- a/tests/integration_tests/db_engine_specs/redshift_tests.py +++ b/tests/integration_tests/db_engine_specs/redshift_tests.py @@ -24,11 +24,11 @@ from sqlalchemy.types import NVARCHAR from superset.db_engine_specs.redshift import RedshiftEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app -class TestRedshiftDbEngineSpec(TestDbEngineSpec): +class TestRedshiftDbEngineSpec(SupersetTestCase): def test_extract_errors(self): """ Test that custom error messages are extracted correctly. diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 7d0ea7cc2ef..dad966c6de6 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -33,7 +33,9 @@ from tests.integration_tests.test_app import app from superset import db, sql_lab from superset.common.db_query_status import QueryStatus from superset.models.core import Database # noqa: F401 -from superset.utils.database import get_example_database, get_main_database # noqa: F401 +from superset.utils.database import ( + get_example_database, +) # noqa: F401 from superset.utils import core as utils, json from superset.models.sql_lab import Query @@ -281,7 +283,7 @@ class TestSqlLabApi(SupersetTestCase): "/api/v1/sqllab/format_sql/", json=data, ) - success_resp = {"result": "SELECT 1\nFROM my_table"} + success_resp = {"result": "SELECT\n 1\nFROM my_table"} resp_data = json.loads(rv.data.decode("utf-8")) self.assertDictEqual(resp_data, success_resp) # noqa: PT009 assert rv.status_code == 200 diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index c100007d383..e4e76472b53 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -206,9 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None: """ from superset.db_engine_specs.base import BaseEngineSpec - class NoLimitDBEngineSpec(BaseEngineSpec): - allow_limit_clause = False - cols: list[ResultSetColumnType] = [ { "column_name": "a", @@ -243,19 +240,7 @@ def test_select_star(mocker: MockerFixture) -> None: latest_partition=False, cols=cols, ) - assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?" - - sql = NoLimitDBEngineSpec.select_star( - database=database, - table=Table("my_table"), - engine=engine, - limit=100, - show_cols=True, - indent=True, - latest_partition=False, - cols=cols, - ) - assert sql == "SELECT a\nFROM my_table" + assert sql == "SELECT\n a\nFROM my_table\nLIMIT ?\nOFFSET ?" def test_extra_table_metadata(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 35fa91108ab..62832c4f3cf 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -254,36 +254,6 @@ def test_cte_query_parsing(original: TypeEngine, expected: str) -> None: assert actual == expected -@pytest.mark.parametrize( - "original,expected,top", - [ - ("SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table", 100), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table", 100), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 10000), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 1000), - ( - """with abc as (select * from test union select * from test1) -select TOP 100 * from currency""", - """WITH abc as (select * from test union select * from test1) -select TOP 100 * from currency""", - 1000, - ), - ("SELECT DISTINCT x from tbl", "SELECT DISTINCT TOP 100 x from tbl", 100), - ("SELECT 1 as cnt", "SELECT TOP 10 1 as cnt", 10), - ( - "select TOP 1000 * from abc where id=1", - "select TOP 10 * from abc where id=1", - 10, - ), - ], -) -def test_top_query_parsing(original: TypeEngine, expected: str, top: int) -> None: - from superset.db_engine_specs.mssql import MssqlEngineSpec - - actual = MssqlEngineSpec.apply_top_to_sql(original, top) - assert actual == expected - - def test_extract_errors() -> None: """ Test that custom error messages are extracted correctly. diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py deleted file mode 100644 index eab03e040d5..00000000000 --- a/tests/unit_tests/db_engine_specs/test_teradata.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument, import-outside-toplevel, protected-access -import pytest - - -@pytest.mark.parametrize( - "limit,original,expected", - [ - (100, "SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table"), - (100, "SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table"), - (10000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), - (1000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), - (100, "SELECT TOP 1000 * FROM My_table", "SELECT TOP 100 * FROM My_table"), - (100, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 100 * FROM My_table"), - (10000, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 1000 * FROM My_table"), - ], -) -def test_apply_top_to_sql_limit( - limit: int, - original: str, - expected: str, -) -> None: - """ - Ensure limits are applied to the query correctly - """ - from superset.db_engine_specs.teradata import TeradataEngineSpec - - assert TeradataEngineSpec.apply_top_to_sql(original, limit) == expected diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 50ccde66055..981fab3e8b6 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -38,6 +38,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.errors import SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.models.core import Database +from superset.sql.parse import LimitMethod from superset.sql_parse import Table from superset.utils import json from tests.unit_tests.conftest import with_feature_flags @@ -910,3 +911,144 @@ def test_get_all_view_names_in_schema(mocker: MockerFixture) -> None: ("third_view", "public", "examples"), } ) + + +@pytest.mark.parametrize( + "sql, limit, force, method, expected", + [ + ( + "SELECT * FROM table", + 100, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 100", + ), + ( + "SELECT * FROM table LIMIT 100", + 10, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 10", + ), + ( + "SELECT * FROM table LIMIT 10", + 100, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 10", + ), + ( + "SELECT * FROM table LIMIT 10", + 100, + True, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 100", + ), + ( + "SELECT * FROM a \t \n ; \t \n ", + 1000, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM a\nLIMIT 1000", + ), + ( + "SELECT 'LIMIT 777'", + 1000, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777'\nLIMIT 1000", + ), + ( + "SELECT * FROM table", + 1000, + False, + LimitMethod.FETCH_MANY, + "SELECT\n *\nFROM table", + ), + ( + "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999", + 1000, + False, + LimitMethod.FORCE_LIMIT, + """SELECT + * +FROM ( + SELECT + * + FROM a + LIMIT 10 +) +LIMIT 1000""", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM + table +LIMIT 99990""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990 ;""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990, 999999""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 99990", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990 +OFFSET 999999""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 999999", + ), + ], +) +def test_apply_limit_to_sql( + sql: str, + limit: int, + force: bool, + method: LimitMethod, + expected: str, + mocker: MockerFixture, +) -> None: + """ + Test the `apply_limit_to_sql` method. + """ + db = Database(database_name="test_database", sqlalchemy_uri="sqlite://") + db_engine_spec = mocker.MagicMock(limit_method=method) + db.get_db_engine_spec = mocker.MagicMock(return_value=db_engine_spec) + + limited = db.apply_limit_to_sql(sql, limit, force) + assert limited == expected diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index f8da1d77b2a..daf5ebe71df 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -24,6 +24,7 @@ from superset.exceptions import SupersetParseError from superset.sql.parse import ( extract_tables_from_statement, KustoKQLStatement, + LimitMethod, split_kql, SQLGLOT_DIALECTS, SQLScript, @@ -302,7 +303,11 @@ def test_format_no_dialect() -> None: """ assert ( SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format() - == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)" + == """SELECT + col +FROM t +WHERE + NOT col IN (1, 2)""" ) @@ -1100,16 +1105,18 @@ FROM ( WHERE TRUE AND TRUE""" - not_optimized = """ -SELECT anon_1.a, - anon_1.b -FROM - (SELECT some_table.a AS a, - some_table.b AS b, - some_table.c AS c - FROM some_table) AS anon_1 -WHERE anon_1.a > 1 - AND anon_1.b = 2""" + not_optimized = """SELECT + anon_1.a, + anon_1.b +FROM ( + SELECT + some_table.a AS a, + some_table.b AS b, + some_table.c AS c + FROM some_table +) AS anon_1 +WHERE + anon_1.a > 1 AND anon_1.b = 2""" assert SQLStatement(sql, "sqlite").optimize().format() == optimized assert SQLStatement(sql, "dremio").optimize().format() == not_optimized @@ -1191,6 +1198,18 @@ def test_firebolt_old_escape_string() -> None: "sql, engine, expected", [ ("SELECT * FROM users LIMIT 10", "postgresql", 10), + ( + """ +WITH cte_example AS ( + SELECT * FROM my_table + LIMIT 100 +) +SELECT * FROM cte_example +LIMIT 10; + """, + "postgresql", + 10, + ), ("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25), ("SELECT * FROM users", "postgresql", None), ("SELECT TOP 5 name FROM employees", "teradatasql", 5), @@ -1221,7 +1240,7 @@ LATERAL generate_series(1, value) AS i; ), ], ) -def test_get_limit_value(sql, engine, expected): +def test_get_limit_value(sql: str, engine: str, expected: str) -> None: assert SQLStatement(sql, engine).get_limit_value() == expected @@ -1243,5 +1262,232 @@ def test_get_limit_value(sql, engine, expected): ), ], ) -def test_get_kql_limit_value(kql, expected): +def test_get_kql_limit_value(kql: str, expected: str) -> None: assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected + + +@pytest.mark.parametrize( + "sql, engine, limit, method, expected", + [ + ( + "SELECT * FROM t", + "postgresql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nLIMIT 10", + ), + ( + "SELECT * FROM t LIMIT 1000", + "postgresql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nLIMIT 10", + ), + ( + "SELECT * FROM t", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM t", + ), + ( + "SELECT * FROM t", + "teradatasql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM t", + ), + ( + "SELECT * FROM t", + "oracle", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nFETCH FIRST 10 ROWS ONLY", + ), + ( + "SELECT * FROM t", + "db2", + 10, + LimitMethod.WRAP_SQL, + "SELECT\n *\nFROM (\n SELECT\n *\n FROM t\n)\nLIMIT 10", + ), + ( + "SEL TOP 1000 * FROM My_table", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SEL TOP 1000 * FROM My_table;", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SEL TOP 1000 * FROM My_table;", + "teradatasql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 1000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "teradatasql", + 10000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 10000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 1000\n *\nFROM My_table", + ), + ( + """ +with abc as (select * from test union select * from test1) +select TOP 100 * from currency + """, + "mssql", + 1000, + LimitMethod.FORCE_LIMIT, + """WITH abc AS ( + SELECT + * + FROM test + UNION + SELECT + * + FROM test1 +) +SELECT +TOP 1000 + * +FROM currency""", + ), + ( + "SELECT DISTINCT x from tbl", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT DISTINCT\nTOP 100\n x\nFROM tbl", + ), + ( + "SELECT 1 as cnt", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n 1 AS cnt", + ), + ( + "select TOP 1000 * from abc where id=1", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM abc\nWHERE\n id = 1", + ), + ( + "SELECT * FROM birth_names -- SOME COMMENT", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000", + ), + ( + "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + """SELECT + * +FROM birth_names /* SOME COMMENT WITH LIMIT 555 */ +LIMIT 1000""", + ), + ( + "SELECT * FROM birth_names LIMIT 555", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM birth_names\nLIMIT 1000", + ), + ], +) +def test_set_limit_value( + sql: str, + engine: str, + limit: int, + method: LimitMethod, + expected: str, +) -> None: + statement = SQLStatement(sql, engine) + statement.set_limit_value(limit, method) + assert statement.format() == expected + + +@pytest.mark.parametrize( + "kql, limit, expected", + [ + ("StormEvents | take 10", 100, "StormEvents | take 100"), + ("StormEvents | limit 20", 10, "StormEvents | limit 10"), + ( + "StormEvents | where State == 'FL' | summarize count()", + 10, + "StormEvents | where State == 'FL' | summarize count() | take 10", + ), + ( + "StormEvents | where name has 'limit 10'", + 10, + "StormEvents | where name has 'limit 10' | take 10", + ), + ("AnotherTable | take 5", 50, "AnotherTable | take 50"), + ( + "datatable(x:int) [1, 2, 3] | take 100", + 10, + "datatable(x:int) [1, 2, 3] | take 10", + ), + ( + """ + Table1 | where msg contains 'abc;xyz' + | limit 5 + """, + 10, + """Table1 | where msg contains 'abc;xyz' + | limit 10""", + ), + ], +) +def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None: + statement = KustoKQLStatement(kql, "kustokql") + statement.set_limit_value(limit) + assert statement.format() == expected diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 89d7dbcb06d..dd04e38dcb6 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -297,7 +297,7 @@ def test_sql_lab_insert_rls_as_subquery( | 3 | 3 | | 4 | 4 |""".strip() ) - assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" + assert query.executed_sql == "SELECT\n c\nFROM t\nLIMIT 6" # now with RLS rls = RowLevelSecurityFilter( @@ -333,7 +333,18 @@ def test_sql_lab_insert_rls_as_subquery( ) assert ( query.executed_sql - == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6" + == """SELECT + c +FROM ( + SELECT + * + FROM t + WHERE + ( + t.c > 5 + ) +) AS t +LIMIT 6""" ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 23aa6b0b125..f0c55e4a267 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1104,46 +1104,6 @@ def test_unknown_select() -> None: assert not ParsedQuery(sql).is_select() -def test_get_query_with_new_limit_comment() -> None: - """ - Test that limit is applied correctly. - """ - query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names -- SOME COMMENT\nLIMIT 1000" - ) - - -def test_get_query_with_new_limit_comment_with_limit() -> None: - """ - Test that limits in comments are ignored. - """ - query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555\nLIMIT 1000" - ) - - -def test_get_query_with_new_limit_lower() -> None: - """ - Test that lower limits are not replaced. - """ - query = ParsedQuery("SELECT * FROM birth_names LIMIT 555") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names LIMIT 555" - ) - - -def test_get_query_with_new_limit_upper() -> None: - """ - Test that higher limits are replaced. - """ - query = ParsedQuery("SELECT * FROM birth_names LIMIT 2000") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names LIMIT 1000" - ) - - def test_basic_breakdown_statements() -> None: """ Test that multiple statements are parsed correctly.