feat: implement CVAS/CTAS in sqlglot (#33525)

This commit is contained in:
Beto Dealmeida
2025-05-28 09:45:59 -04:00
committed by GitHub
parent 0abe6eed89
commit ea5a609d0b
2 changed files with 220 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from sqlglot import Dialects, parse_one
from superset.exceptions import SupersetParseError
from superset.sql.parse import (
CTASMethod,
extract_tables_from_statement,
KustoKQLStatement,
LimitMethod,
@@ -768,6 +769,11 @@ Events | take 100""",
"kustokql",
["let foo = 1", "tbl | where bar == foo"],
),
(
"SELECT 1; -- extraneous comment",
"postgresql",
["SELECT\n 1 /* extraneous comment */"],
),
],
)
def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None:
@@ -2247,3 +2253,135 @@ def test_rls_predicate_transformer(
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, table, expected",
[
(
"SELECT * FROM some_table",
Table("some_table"),
"""
CREATE TABLE some_table AS
SELECT
*
FROM some_table
""".strip(),
),
(
"SELECT * FROM some_table",
Table("some_table", "schema1", "catalog1"),
"""
CREATE TABLE catalog1.schema1.some_table AS
SELECT
*
FROM some_table
""".strip(),
),
],
)
def test_as_create_table(sql: str, table: Table, expected: str) -> None:
"""
Test the `as_create_table` method.
"""
statement = SQLStatement(sql)
create_table = statement.as_create_table(table, CTASMethod.TABLE)
assert create_table.format() == expected
@pytest.mark.parametrize(
"sql, engine, expected",
[
("SELECT * FROM table", "postgresql", True),
(
"""
-- comment
SELECT * FROM table
-- comment 2
""",
"mysql",
True,
),
(
"""
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
""",
"mysql",
True,
),
(
"""
-- comment
EXPLAIN SELECT * FROM table
-- comment 2
""",
"mysql",
False,
),
(
"""
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
""",
"mysql",
False,
),
],
)
def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None:
"""
Test the `is_valid_ctas` method.
"""
assert SQLScript(sql, engine).is_valid_ctas() == expected
@pytest.mark.parametrize(
"sql, engine, expected",
[
("SELECT * FROM table", "postgresql", True),
(
"""
-- comment
SELECT * FROM table
-- comment 2
""",
"mysql",
True,
),
(
"""
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
""",
"mysql",
False,
),
(
"""
-- comment
SELECT value as foo;
-- comment 2
""",
"mysql",
True,
),
(
"""
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
""",
"mysql",
False,
),
],
)
def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
"""
Test the `is_valid_cvas` method.
"""
assert SQLScript(sql, engine).is_valid_cvas() == expected