diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 9e006c28099..dc9ca632bba 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -32,6 +32,7 @@ from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError +from sqlglot.expressions import Func from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope @@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): return SQLStatement(sql, self.engine, optimized) + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the statement. + + :param functions: Set of function names to check for + :return: True if any of the functions are present + """ + present = { + ( + function.sql_name() + if function.sql_name() != "ANONYMOUS" + else function.name.upper() + ) + for function in self._parsed.find_all(Func) + } + return any(function.upper() in present for function in functions) + class KQLSplitState(enum.Enum): """ @@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]): """ return KustoKQLStatement(self._sql, self.engine, self._parsed) + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the statement. + + :param functions: Set of function names to check for + :return: Always True, since function detection is not supported for KQL + """ + logger.warning("Kusto KQL doesn't support checking for functions present.") + return True + class SQLScript: """ @@ -684,6 +712,18 @@ class SQLScript: return script + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. 
+ + :param functions: Set of function names to check for + :return: True if any of the functions are present + """ + return any( + statement.check_functions_present(functions) + for statement in self.statements + ) + def extract_tables_from_statement( statement: exp.Expression, diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 2cd8d0ac9dc..8fae4507efa 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -31,7 +31,6 @@ from sqlalchemy import and_ from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( - Function, Identifier, IdentifierList, Parenthesis, @@ -181,7 +180,7 @@ def check_sql_functions_exist( :param function_list: The list of functions to search for :param engine: The engine to use for parsing the SQL statement """ - return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) + return SQLScript(sql, engine=engine).check_functions_present(function_list) def strip_comments_from_sql(statement: str, engine: str = "base") -> str: @@ -229,34 +228,6 @@ class ParsedQuery: self._tables = self._extract_tables_from_sql() return self._tables - def _check_functions_exist_in_token( - self, token: Token, functions: set[str] - ) -> bool: - if ( - isinstance(token, Function) - and token.get_name() is not None - and token.get_name().lower() in functions - ): - return True - if hasattr(token, "tokens"): - for inner_token in token.tokens: - if self._check_functions_exist_in_token(inner_token, functions): - return True - return False - - def check_functions_exist(self, functions: set[str]) -> bool: - """ - Check if the SQL statement contains any of the specified functions. 
- - :param functions: A set of functions to search for - :return: True if the statement contains any of the specified functions - """ - for statement in self._parsed: - for token in statement.tokens: - if self._check_functions_exist_in_token(token, functions): - return True - return False - def _extract_tables_from_sql(self) -> set[Table]: """ Extract all table references in a query. diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index c3533d131ee..357dcade337 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -580,10 +580,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data ) rv = test_client.post(uri, json={}) assert rv.status_code == 422 - - assert "error" in rv.json - if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL": - assert "INCORRECT SQL" in rv.json.get("error") + assert rv.json["errors"][0]["error_type"] == "INVALID_SQL_ERROR" @with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 9c814a0f422..23aa6b0b125 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None: ) +def test_check_sql_functions_exist_with_comments() -> None: + """ + Test sql functions are detected correctly with comments + """ + assert not ( + check_sql_functions_exist( + "select a, b from version/**/", {"version"}, "postgresql" + ) + ) + + assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql") + + assert check_sql_functions_exist( + "select version from version/**/()", {"version"}, "postgresql" + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version from version/**/()) as a", + {"version"}, + "postgresql", + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select 
version/**/()) as a", + {"version"}, + "postgresql", + ) + + def test_sanitize_clause_valid(): # regular clauses assert sanitize_clause("col = 1") == "col = 1"