diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 1ca100975f2..fbda84570e3 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -108,6 +108,11 @@ class LimitMethod(enum.Enum): FETCH_MANY = enum.auto() +class CTASMethod(enum.Enum): + TABLE = enum.auto() + VIEW = enum.auto() + + class RLSMethod(enum.Enum): """ Methods for enforcing RLS. @@ -381,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + raise NotImplementedError() + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -437,6 +448,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + """ + Rewrite the statement as a `CREATE TABLE AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new SQLStatement with the CTE. + """ + raise NotImplementedError() + def apply_rls( self, catalog: str | None, @@ -480,7 +501,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ dialect = SQLGLOT_DIALECTS.get(engine) try: - return sqlglot.parse(script, dialect=dialect) + statements = sqlglot.parse(script, dialect=dialect) except sqlglot.errors.ParseError as ex: error = ex.errors[0] raise SupersetParseError( @@ -497,6 +518,20 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): message="Unable to parse script", ) from ex + # `sqlglot` will parse comments after the last semicolon as a separate + # statement; move them back to the last token in the last real statement + if len(statements) > 1 and isinstance(statements[-1], exp.Semicolon): + last_statement = statements.pop() + target = statements[-1] + for node in statements[-1].walk(): + if hasattr(node, "comments"): + target = node + + target.comments = target.comments or [] + target.comments.extend(last_statement.comments) + + return statements + @classmethod def split_script( cls, @@ -572,6 +607,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): dialect = SQLGLOT_DIALECTS.get(engine) return extract_tables_from_statement(parsed, dialect) + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return isinstance(self._parsed, exp.Select) + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -733,6 +774,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): engine=self.engine, ) + def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + """ + Rewrite the statement as a `CREATE TABLE AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new SQLStatement with the create table statement. + """ + create_table = exp.Create( + this=sqlglot.parse_one(str(table), into=exp.Table), + kind=method.name, + expression=self._parsed.copy(), + ) + + return SQLStatement(ast=create_table, engine=self.engine) + def apply_rls( self, catalog: str | None, @@ -988,6 +1045,12 @@ class KustoKQLStatement(BaseSQLStatement[str]): return {} + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return not self._parsed.startswith(".") + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -1142,6 +1205,24 @@ class SQLScript: for statement in self.statements ) + def is_valid_ctas(self) -> bool: + """ + Check if the script contains a valid CTAS statement. + + CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last + statement is a `SELECT`. + """ + return self.statements[-1].is_select() + + def is_valid_cvas(self) -> bool: + """ + Check if the script contains a valid CVAS statement. + + CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single + `SELECT` statement. + """ + return len(self.statements) == 1 and self.statements[0].is_select() + def extract_tables_from_statement( statement: exp.Expression, diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 72907f92c7b..df837518b52 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -22,6 +22,7 @@ from sqlglot import Dialects, parse_one from superset.exceptions import SupersetParseError from superset.sql.parse import ( + CTASMethod, extract_tables_from_statement, KustoKQLStatement, LimitMethod, @@ -768,6 +769,11 @@ Events | take 100""", "kustokql", ["let foo = 1", "tbl | where bar == foo"], ), + ( + "SELECT 1; -- extraneous comment", + "postgresql", + ["SELECT\n 1 /* extraneous comment */"], + ), ], ) def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None: @@ -2247,3 +2253,135 @@ def test_rls_predicate_transformer( RLSMethod.AS_PREDICATE, ) assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, table, expected", + [ + ( + "SELECT * FROM some_table", + Table("some_table"), + """ +CREATE TABLE some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ( + "SELECT * FROM some_table", + Table("some_table", "schema1", "catalog1"), + """ +CREATE TABLE catalog1.schema1.some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ], +) +def test_as_create_table(sql: str, table: Table, expected: str) -> None: + """ + Test the `as_create_table` method. + """ + statement = SQLStatement(sql) + create_table = statement.as_create_table(table, CTASMethod.TABLE) + assert create_table.format() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +EXPLAIN SELECT * FROM table +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_ctas` method. + """ + assert SQLScript(sql, engine).is_valid_ctas() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +-- comment +SELECT value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_cvas` method. + """ + assert SQLScript(sql, engine).is_valid_cvas() == expected