mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
feat(sqllab): use sqlglot instead of sqlparse (#33542)
This commit is contained in:
@@ -36,7 +36,7 @@ from sqlalchemy.dialects.mysql import dialect
|
||||
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME
|
||||
from tests.integration_tests.test_app import app, login
|
||||
from superset.sql_parse import CtasMethod
|
||||
from superset.sql.parse import CTASMethod
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import BaseDatasource, SqlaTable
|
||||
from superset.models import core as models
|
||||
@@ -387,7 +387,7 @@ class SupersetTestCase(TestCase):
|
||||
select_as_cta=False,
|
||||
tmp_table_name=None,
|
||||
schema=None,
|
||||
ctas_method=CtasMethod.TABLE,
|
||||
ctas_method=CTASMethod.TABLE,
|
||||
template_params="{}",
|
||||
):
|
||||
if username:
|
||||
@@ -400,7 +400,7 @@ class SupersetTestCase(TestCase):
|
||||
"client_id": client_id,
|
||||
"queryLimit": query_limit,
|
||||
"sql_editor_id": sql_editor_id,
|
||||
"ctas_method": ctas_method,
|
||||
"ctas_method": ctas_method.name,
|
||||
"templateParams": template_params,
|
||||
}
|
||||
if tmp_table_name:
|
||||
|
||||
@@ -39,7 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.extensions import celery_app
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import ParsedQuery, CtasMethod
|
||||
from superset.sql.parse import CTASMethod
|
||||
from superset.utils.core import backend
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.conftest import CTAS_SCHEMA_NAME
|
||||
@@ -76,13 +76,19 @@ def setup_sqllab():
|
||||
db.session.query(Query).delete()
|
||||
db.session.commit()
|
||||
for tbl in TMP_TABLES:
|
||||
drop_table_if_exists(f"{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE)
|
||||
drop_table_if_exists(f"{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW)
|
||||
drop_table_if_exists(
|
||||
f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE
|
||||
f"{tbl}_{CTASMethod.TABLE.name.lower()}", CTASMethod.TABLE
|
||||
)
|
||||
drop_table_if_exists(
|
||||
f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW
|
||||
f"{tbl}_{CTASMethod.VIEW.name.lower()}", CTASMethod.VIEW
|
||||
)
|
||||
drop_table_if_exists(
|
||||
f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.TABLE.name.lower()}",
|
||||
CTASMethod.TABLE,
|
||||
)
|
||||
drop_table_if_exists(
|
||||
f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.VIEW.name.lower()}",
|
||||
CTASMethod.VIEW,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,7 +96,7 @@ def run_sql(
|
||||
test_client,
|
||||
sql,
|
||||
cta=False,
|
||||
ctas_method=CtasMethod.TABLE,
|
||||
ctas_method=CTASMethod.TABLE,
|
||||
tmp_table="tmp",
|
||||
async_=False,
|
||||
):
|
||||
@@ -104,14 +110,14 @@ def run_sql(
|
||||
select_as_cta=cta,
|
||||
tmp_table_name=tmp_table,
|
||||
client_id="".join(random.choice(string.ascii_lowercase) for i in range(5)), # noqa: S311
|
||||
ctas_method=ctas_method,
|
||||
ctas_method=ctas_method.name,
|
||||
),
|
||||
).json
|
||||
|
||||
|
||||
def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
|
||||
def drop_table_if_exists(table_name: str, table_type: CTASMethod) -> None:
|
||||
"""Drop table if it exists, works on any DB"""
|
||||
sql = f"DROP {table_type} IF EXISTS {table_name}"
|
||||
sql = f"DROP {table_type.name} IF EXISTS {table_name}"
|
||||
database = get_example_database()
|
||||
with database.get_sqla_engine() as engine:
|
||||
engine.execute(sql)
|
||||
@@ -124,10 +130,10 @@ def quote_f(value: Optional[str]):
|
||||
return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
|
||||
|
||||
|
||||
def cta_result(ctas_method: CtasMethod):
|
||||
def cta_result(ctas_method: CTASMethod):
|
||||
if backend() != "presto":
|
||||
return [], []
|
||||
if ctas_method == CtasMethod.TABLE:
|
||||
if ctas_method == CTASMethod.TABLE:
|
||||
return [{"rows": 1}], [{"name": "rows", "type": "BIGINT", "is_dttm": False}]
|
||||
return [{"result": True}], [{"name": "result", "type": "BOOLEAN", "is_dttm": False}]
|
||||
|
||||
@@ -143,13 +149,13 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW])
|
||||
def test_run_sync_query_dont_exist(test_client, ctas_method):
|
||||
examples_db = get_example_database()
|
||||
engine_name = examples_db.db_engine_spec.engine_name
|
||||
sql_dont_exist = "SELECT name FROM table_dont_exist"
|
||||
result = run_sql(test_client, sql_dont_exist, cta=True, ctas_method=ctas_method)
|
||||
if backend() == "sqlite" and ctas_method == CtasMethod.VIEW:
|
||||
if backend() == "sqlite" and ctas_method == CTASMethod.VIEW:
|
||||
assert QueryStatus.SUCCESS == result["status"], result
|
||||
elif backend() == "presto":
|
||||
assert (
|
||||
@@ -188,9 +194,9 @@ def test_run_sync_query_dont_exist(test_client, ctas_method):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
def test_run_sync_query_cta(test_client, ctas_method):
|
||||
tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}"
|
||||
@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW])
|
||||
def test_run_sync_query_cta(test_client, ctas_method: CTASMethod) -> None:
|
||||
tmp_table_name = f"{TEST_SYNC}_{ctas_method.name.lower()}"
|
||||
result = run_sql(
|
||||
test_client, QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method
|
||||
)
|
||||
@@ -218,16 +224,44 @@ def test_run_sync_query_cta_no_data(test_client):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
@pytest.mark.parametrize(
|
||||
"ctas_method, expected",
|
||||
[
|
||||
(
|
||||
CTASMethod.TABLE,
|
||||
"""
|
||||
CREATE TABLE sqllab_test_db.test_sync_cta_table AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
CTASMethod.VIEW,
|
||||
"""
|
||||
CREATE VIEW sqllab_test_db.test_sync_cta_view AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
@mock.patch( # noqa: PT008
|
||||
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
|
||||
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
|
||||
)
|
||||
def test_run_sync_query_cta_config(test_client, ctas_method):
|
||||
def test_run_sync_query_cta_config(
|
||||
test_client,
|
||||
ctas_method: CTASMethod,
|
||||
expected: str,
|
||||
) -> None:
|
||||
if backend() == "sqlite":
|
||||
# sqlite doesn't support schemas
|
||||
return
|
||||
tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}"
|
||||
tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.name.lower()}"
|
||||
result = run_sql(
|
||||
test_client, QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name
|
||||
)
|
||||
@@ -235,10 +269,7 @@ def test_run_sync_query_cta_config(test_client, ctas_method):
|
||||
assert cta_result(ctas_method) == (result["data"], result["columns"])
|
||||
|
||||
query = get_query_by_id(result["query"]["serverId"])
|
||||
assert (
|
||||
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
|
||||
== query.executed_sql
|
||||
)
|
||||
assert query.executed_sql == expected
|
||||
assert query.select_sql == get_select_star(
|
||||
tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME
|
||||
)
|
||||
@@ -249,16 +280,44 @@ def test_run_sync_query_cta_config(test_client, ctas_method):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
@pytest.mark.parametrize(
|
||||
"ctas_method, expected",
|
||||
[
|
||||
(
|
||||
CTASMethod.TABLE,
|
||||
"""
|
||||
CREATE TABLE sqllab_test_db.test_async_cta_config_table AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
CTASMethod.VIEW,
|
||||
"""
|
||||
CREATE VIEW sqllab_test_db.test_async_cta_config_view AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
@mock.patch( # noqa: PT008
|
||||
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
|
||||
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
|
||||
)
|
||||
def test_run_async_query_cta_config(test_client, ctas_method):
|
||||
def test_run_async_query_cta_config(
|
||||
test_client,
|
||||
ctas_method: CTASMethod,
|
||||
expected: str,
|
||||
) -> None:
|
||||
if backend() == "sqlite":
|
||||
# sqlite doesn't support schemas
|
||||
return
|
||||
tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}"
|
||||
tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.name.lower()}"
|
||||
result = run_sql(
|
||||
test_client,
|
||||
QUERY,
|
||||
@@ -275,18 +334,43 @@ def test_run_async_query_cta_config(test_client, ctas_method):
|
||||
get_select_star(tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME)
|
||||
== query.select_sql
|
||||
)
|
||||
assert (
|
||||
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
|
||||
== query.executed_sql
|
||||
)
|
||||
assert query.executed_sql == expected
|
||||
|
||||
delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
def test_run_async_cta_query(test_client, ctas_method):
|
||||
table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}"
|
||||
@pytest.mark.parametrize(
|
||||
"ctas_method, expected",
|
||||
[
|
||||
(
|
||||
CTASMethod.TABLE,
|
||||
"""
|
||||
CREATE TABLE test_async_cta_table AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
CTASMethod.VIEW,
|
||||
"""
|
||||
CREATE VIEW test_async_cta_view AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_async_cta_query(
|
||||
test_client,
|
||||
ctas_method: CTASMethod,
|
||||
expected: str,
|
||||
) -> None:
|
||||
table_name = f"{TEST_ASYNC_CTA}_{ctas_method.name.lower()}"
|
||||
result = run_sql(
|
||||
test_client,
|
||||
QUERY,
|
||||
@@ -301,7 +385,7 @@ def test_run_async_cta_query(test_client, ctas_method):
|
||||
assert QueryStatus.SUCCESS == query.status
|
||||
assert get_select_star(table_name, query.limit) in query.select_sql
|
||||
|
||||
assert f"CREATE {ctas_method} {table_name} AS \n{QUERY}" == query.executed_sql
|
||||
assert query.executed_sql == expected
|
||||
assert QUERY == query.sql
|
||||
assert query.rows == (1 if backend() == "presto" else 0)
|
||||
assert query.select_as_cta
|
||||
@@ -311,9 +395,37 @@ def test_run_async_cta_query(test_client, ctas_method):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
|
||||
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
|
||||
tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}"
|
||||
@pytest.mark.parametrize(
|
||||
"ctas_method, expected",
|
||||
[
|
||||
(
|
||||
CTASMethod.TABLE,
|
||||
"""
|
||||
CREATE TABLE test_async_lower_limit_table AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
CTASMethod.VIEW,
|
||||
"""
|
||||
CREATE VIEW test_async_lower_limit_view AS
|
||||
SELECT
|
||||
name
|
||||
FROM birth_names
|
||||
LIMIT 1
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_async_cta_query_with_lower_limit(
|
||||
test_client,
|
||||
ctas_method: CTASMethod,
|
||||
expected: str,
|
||||
) -> None:
|
||||
tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.name.lower()}"
|
||||
result = run_sql(
|
||||
test_client,
|
||||
QUERY,
|
||||
@@ -332,7 +444,7 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
|
||||
else get_select_star(tmp_table, query.limit)
|
||||
)
|
||||
|
||||
assert f"CREATE {ctas_method} {tmp_table} AS \n{QUERY}" == query.executed_sql
|
||||
assert query.executed_sql == expected
|
||||
assert QUERY == query.sql
|
||||
|
||||
assert query.rows == (1 if backend() == "presto" else 0)
|
||||
@@ -442,28 +554,6 @@ def test_msgpack_payload_serialization():
|
||||
assert isinstance(serialized, bytes)
|
||||
|
||||
|
||||
def test_create_table_as():
|
||||
q = ParsedQuery("SELECT * FROM outer_space;")
|
||||
|
||||
assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")
|
||||
assert (
|
||||
"DROP TABLE IF EXISTS tmp;\nCREATE TABLE tmp AS \nSELECT * FROM outer_space"
|
||||
== q.as_create_table("tmp", overwrite=True)
|
||||
)
|
||||
|
||||
# now without a semicolon
|
||||
q = ParsedQuery("SELECT * FROM outer_space")
|
||||
assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")
|
||||
|
||||
# now a multi-line query
|
||||
multi_line_query = "SELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'"
|
||||
q = ParsedQuery(multi_line_query)
|
||||
assert (
|
||||
"CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'"
|
||||
== q.as_create_table("tmp")
|
||||
)
|
||||
|
||||
|
||||
def test_in_app_context():
|
||||
@celery_app.task(bind=True)
|
||||
def my_task(self):
|
||||
@@ -484,8 +574,8 @@ def test_in_app_context():
|
||||
)
|
||||
|
||||
|
||||
def delete_tmp_view_or_table(name: str, db_object_type: str):
|
||||
db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}")
|
||||
def delete_tmp_view_or_table(name: str, ctas_method: CTASMethod):
|
||||
db.get_engine().execute(f"DROP {ctas_method.name} IF EXISTS {name}")
|
||||
|
||||
|
||||
def wait_for_success(result):
|
||||
|
||||
@@ -31,16 +31,15 @@ from superset.db_engine_specs import BaseEngineSpec
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetErrorException
|
||||
from superset.exceptions import SupersetErrorException, SupersetInvalidCVASException
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.sql.parse import CTASMethod
|
||||
from superset.sql_lab import (
|
||||
cancel_query,
|
||||
execute_sql_statements,
|
||||
apply_limit_if_exists,
|
||||
)
|
||||
from superset.sql_parse import CtasMethod
|
||||
from superset.utils.core import backend
|
||||
from superset.utils import json
|
||||
from superset.utils.json import datetime_to_epoch # noqa: F401
|
||||
@@ -132,31 +131,13 @@ class TestSqlLab(SupersetTestCase):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
data = self.run_sql("DELETE FROM birth_names", "1")
|
||||
assert data == {
|
||||
"errors": [
|
||||
{
|
||||
"message": (
|
||||
"This database does not allow for DDL/DML, and the query "
|
||||
"could not be parsed to confirm it is a read-only query. Please " # noqa: E501
|
||||
"contact your administrator for more assistance."
|
||||
),
|
||||
"error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR,
|
||||
"level": ErrorLevel.ERROR,
|
||||
"extra": {
|
||||
"issue_codes": [
|
||||
{
|
||||
"code": 1022,
|
||||
"message": "Issue 1022 - Database does not allow data manipulation.", # noqa: E501
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
assert (
|
||||
data["errors"][0]["error_type"] == SupersetErrorType.DML_NOT_ALLOWED_ERROR
|
||||
)
|
||||
|
||||
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
|
||||
@parameterized.expand([CTASMethod.TABLE, CTASMethod.VIEW])
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_sql_json_cta_dynamic_db(self, ctas_method):
|
||||
def test_sql_json_cta_dynamic_db(self, ctas_method: CTASMethod) -> None:
|
||||
examples_db = get_example_database()
|
||||
if examples_db.backend == "sqlite":
|
||||
# sqlite doesn't support database creation
|
||||
@@ -170,7 +151,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
examples_db.allow_ctas = True # enable cta
|
||||
|
||||
self.login(ADMIN_USERNAME)
|
||||
tmp_table_name = f"test_target_{ctas_method.lower()}"
|
||||
tmp_table_name = f"test_target_{ctas_method.name.lower()}"
|
||||
self.run_sql(
|
||||
"SELECT * FROM birth_names",
|
||||
"1",
|
||||
@@ -195,7 +176,9 @@ class TestSqlLab(SupersetTestCase):
|
||||
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
|
||||
|
||||
# cleanup
|
||||
engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
|
||||
engine.execute(
|
||||
f"DROP {ctas_method.name} admin_database.{tmp_table_name}"
|
||||
)
|
||||
examples_db.allow_ctas = old_allow_ctas
|
||||
db.session.commit()
|
||||
|
||||
@@ -608,10 +591,10 @@ class TestSqlLab(SupersetTestCase):
|
||||
|
||||
@mock.patch("superset.sql_lab.db")
|
||||
@mock.patch("superset.sql_lab.get_query")
|
||||
@mock.patch("superset.sql_lab.execute_sql_statement")
|
||||
@mock.patch("superset.sql_lab.execute_query")
|
||||
def test_execute_sql_statements(
|
||||
self,
|
||||
mock_execute_sql_statement,
|
||||
mock_execute_query,
|
||||
mock_get_query,
|
||||
mock_db,
|
||||
):
|
||||
@@ -623,7 +606,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
"""
|
||||
)
|
||||
mock_db = mock.MagicMock() # noqa: F841
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query = mock.MagicMock(select_as_cta=False)
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
|
||||
@@ -641,30 +624,20 @@ class TestSqlLab(SupersetTestCase):
|
||||
expand_data=False,
|
||||
log_params=None,
|
||||
)
|
||||
mock_execute_sql_statement.assert_has_calls(
|
||||
mock_execute_query.assert_has_calls(
|
||||
[
|
||||
mock.call(
|
||||
"-- comment\nSET @value = 42",
|
||||
mock_query,
|
||||
mock_cursor,
|
||||
None,
|
||||
False,
|
||||
),
|
||||
mock.call(
|
||||
"SELECT /*+ hint */ @value AS foo",
|
||||
mock_query,
|
||||
mock_cursor,
|
||||
None,
|
||||
False,
|
||||
),
|
||||
mock.call(mock_query, mock_cursor, None),
|
||||
mock.call(mock_query, mock_cursor, None),
|
||||
]
|
||||
)
|
||||
|
||||
@mock.patch("superset.sql_lab.results_backend", None)
|
||||
@mock.patch("superset.sql_lab.get_query")
|
||||
@mock.patch("superset.sql_lab.execute_sql_statement")
|
||||
@mock.patch("superset.sql_lab.execute_query")
|
||||
def test_execute_sql_statements_no_results_backend(
|
||||
self, mock_execute_sql_statement, mock_get_query
|
||||
self,
|
||||
mock_execute_query,
|
||||
mock_get_query,
|
||||
):
|
||||
sql = dedent(
|
||||
"""
|
||||
@@ -712,10 +685,10 @@ class TestSqlLab(SupersetTestCase):
|
||||
|
||||
@mock.patch("superset.sql_lab.db")
|
||||
@mock.patch("superset.sql_lab.get_query")
|
||||
@mock.patch("superset.sql_lab.execute_sql_statement")
|
||||
@mock.patch("superset.sql_lab.execute_query")
|
||||
def test_execute_sql_statements_ctas(
|
||||
self,
|
||||
mock_execute_sql_statement,
|
||||
mock_execute_query,
|
||||
mock_get_query,
|
||||
mock_db,
|
||||
):
|
||||
@@ -727,7 +700,13 @@ class TestSqlLab(SupersetTestCase):
|
||||
"""
|
||||
)
|
||||
mock_db = mock.MagicMock() # noqa: F841
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query = mock.MagicMock(
|
||||
select_as_cta=True,
|
||||
ctas_method=CTASMethod.TABLE.name,
|
||||
tmp_table_name="table",
|
||||
tmp_schema_name="schema",
|
||||
catalog="catalog",
|
||||
)
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
|
||||
@@ -738,7 +717,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
|
||||
# set the query to CTAS
|
||||
mock_query.select_as_cta = True
|
||||
mock_query.ctas_method = CtasMethod.TABLE
|
||||
mock_query.ctas_method = CTASMethod.TABLE.name
|
||||
|
||||
execute_sql_statements(
|
||||
query_id=1,
|
||||
@@ -749,22 +728,10 @@ class TestSqlLab(SupersetTestCase):
|
||||
expand_data=False,
|
||||
log_params=None,
|
||||
)
|
||||
mock_execute_sql_statement.assert_has_calls(
|
||||
mock_execute_query.assert_has_calls(
|
||||
[
|
||||
mock.call(
|
||||
"-- comment\nSET @value = 42",
|
||||
mock_query,
|
||||
mock_cursor,
|
||||
None,
|
||||
False,
|
||||
),
|
||||
mock.call(
|
||||
"SELECT /*+ hint */ @value AS foo",
|
||||
mock_query,
|
||||
mock_cursor,
|
||||
None,
|
||||
True, # apply_ctas
|
||||
),
|
||||
mock.call(mock_query, mock_cursor, None),
|
||||
mock.call(mock_query, mock_cursor, None),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -795,7 +762,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
)
|
||||
|
||||
# try invalid CVAS
|
||||
mock_query.ctas_method = CtasMethod.VIEW
|
||||
mock_query.ctas_method = CTASMethod.VIEW.name
|
||||
sql = dedent(
|
||||
"""
|
||||
-- comment
|
||||
@@ -803,7 +770,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
SELECT /*+ hint */ @value AS foo;
|
||||
"""
|
||||
)
|
||||
with pytest.raises(SupersetErrorException) as excinfo:
|
||||
with pytest.raises(SupersetInvalidCVASException) as excinfo:
|
||||
execute_sql_statements(
|
||||
query_id=1,
|
||||
rendered_query=sql,
|
||||
@@ -870,29 +837,6 @@ class TestSqlLab(SupersetTestCase):
|
||||
]
|
||||
}
|
||||
|
||||
def test_apply_limit_if_exists_when_incremented_limit_is_none(self):
|
||||
sql = """
|
||||
SET @value = 42;
|
||||
SELECT @value AS foo;
|
||||
"""
|
||||
database = get_example_database()
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query.limit = 300
|
||||
final_sql = apply_limit_if_exists(database, None, mock_query, sql)
|
||||
|
||||
assert final_sql == sql
|
||||
|
||||
def test_apply_limit_if_exists_when_increased_limit(self):
|
||||
sql = """
|
||||
SET @value = 42;
|
||||
SELECT @value AS foo;
|
||||
"""
|
||||
database = get_example_database()
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query.limit = 300
|
||||
final_sql = apply_limit_if_exists(database, 1000, mock_query, sql)
|
||||
assert "LIMIT 1000" in final_sql
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", [HiveEngineSpec, PrestoEngineSpec])
|
||||
def test_cancel_query_implicit(spec: BaseEngineSpec) -> None:
|
||||
|
||||
Reference in New Issue
Block a user