diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 70526ea12ee..be2b3ee7d1a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -2070,10 +2070,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return False - @classmethod - def parse_sql(cls, sql: str) -> list[str]: - return [str(s).strip(" ;") for s in sqlparse.parse(sql)] - @classmethod def get_impersonation_key(cls, user: User | None) -> Any: """ diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 2081f6c89ce..9181b078592 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -153,11 +153,3 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method return f"""datetime({dttm.isoformat(timespec="microseconds")})""" return None - - @classmethod - def parse_sql(cls, sql: str) -> list[str]: - """ - Kusto supports a single query statement, but it could include sub queries - and variables declared via let keyword. - """ - return [sql] diff --git a/superset/models/core.py b/superset/models/core.py index bc545c95dc7..1768aa41fdd 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -659,7 +659,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable schema: str | None = None, mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: - sqls = self.db_engine_spec.parse_sql(sql) + script = SQLScript(sql, self.db_engine_spec.engine) with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url @@ -676,8 +676,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with self.get_raw_connection(catalog=catalog, schema=schema) as conn: cursor = conn.cursor() df = None - for i, sql_ in enumerate(sqls): - sql_ = self.mutate_sql_based_on_config(sql_, is_split=True) + for i, statement in enumerate(script.statements): + sql_ = self.mutate_sql_based_on_config( + statement.format(), + is_split=True, + ) _log_query(sql_) with event_logger.log_context( action="execute_sql", @@ -686,7 +689,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ): self.db_engine_spec.execute(cursor, sql_, self) - rows = self.fetch_rows(cursor, i == len(sqls) - 1) + rows = self.fetch_rows(cursor, i == len(script.statements) - 1) if rows is not None: df = self.load_into_dataframe(cursor.description, rows) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index e4e76472b53..421be19eb6d 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -49,32 +49,6 @@ def test_get_text_clause_with_colon() -> None: assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')" -def test_parse_sql_single_statement() -> None: - """ - `parse_sql` should properly strip leading and trailing spaces and semicolons - """ - - from superset.db_engine_specs.base import BaseEngineSpec - - queries = BaseEngineSpec.parse_sql(" SELECT foo FROM tbl ; ") - assert queries == ["SELECT foo FROM tbl"] - - -def test_parse_sql_multi_statement() -> None: - """ - For string with multiple SQL-statements `parse_sql` method should return list - where each element represents the single SQL-statement - """ - - from superset.db_engine_specs.base import BaseEngineSpec - - queries = BaseEngineSpec.parse_sql("SELECT foo FROM tbl1; SELECT bar FROM tbl2;") - assert queries == [ - "SELECT foo FROM tbl1", - "SELECT bar FROM tbl2", - ] - - def test_validate_db_uri(mocker: MockerFixture) -> None: """ Ensures that the `validate_database_uri` method invokes the validator correctly diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index d3a49f86e9c..a21e82a5676 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -80,19 +80,6 @@ def test_kql_has_mutation(kql: str, expected: bool) -> None: ) -def test_kql_parse_sql() -> None: - """ - parse_sql method should always return a list with a single element - which is an original query - """ - - from superset.db_engine_specs.kusto import KustoKqlEngineSpec - - queries = KustoKqlEngineSpec.parse_sql("let foo = 1; tbl | where bar == foo") - - assert queries == ["let foo = 1; tbl | where bar == foo"] - - @pytest.mark.parametrize( "target_type,expected_result", [ diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index d4341524c13..3d0746b5c79 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -747,6 +747,34 @@ Events | take 100""", assert query.get_settings() == {"querytrace": True} +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ( + " SELECT foo FROM tbl ; ", + "postgresql", + ["SELECT\n foo\nFROM tbl"], + ), + ( + "SELECT foo FROM tbl1; SELECT bar FROM tbl2;", + "postgresql", + ["SELECT\n foo\nFROM tbl1", "SELECT\n bar\nFROM tbl2"], + ), + ( + "let foo = 1; tbl | where bar == foo", + "kustokql", + ["let foo = 1", "tbl | where bar == foo"], + ), + ], +) +def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None: + """ + Test the `SQLScript` class with a script that has a single statement. + """ + script = SQLScript(sql, engine) + assert [statement.format() for statement in script.statements] == expected + + def test_sqlstatement() -> None: """ Test the `SQLStatement` class.