mirror of
https://github.com/apache/superset.git
synced 2026-04-21 00:54:44 +00:00
feat(sqllab): use sqlglot instead of sqlparse (#33542)
This commit is contained in:
@@ -21,105 +21,46 @@ from unittest import mock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import sqlparse
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, SupersetErrorException
|
||||
from superset.models.core import Database
|
||||
from superset.sql_lab import execute_sql_statements, get_sql_results
|
||||
from superset.utils.core import override_user
|
||||
from superset.sql.parse import SQLStatement, Table
|
||||
from superset.sql_lab import (
|
||||
apply_rls,
|
||||
execute_query,
|
||||
execute_sql_statements,
|
||||
get_predicates_for_table,
|
||||
get_sql_results,
|
||||
)
|
||||
from tests.unit_tests.models.core_test import oauth2_client_info
|
||||
|
||||
|
||||
def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
|
||||
def test_execute_query(mocker: MockerFixture, app: None) -> None:
|
||||
"""
|
||||
Simple test for `execute_sql_statement`.
|
||||
"""
|
||||
from superset.sql_lab import execute_sql_statement
|
||||
|
||||
sql_statement = "SELECT 42 AS answer"
|
||||
|
||||
query = mocker.MagicMock()
|
||||
query.executed_sql = "SELECT 42 AS answer"
|
||||
|
||||
query.limit = 1
|
||||
query.select_as_cta_used = False
|
||||
database = query.database
|
||||
database.allow_dml = False
|
||||
database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2"
|
||||
database.mutate_sql_based_on_config.return_value = "SELECT 42 AS answer LIMIT 2"
|
||||
db_engine_spec = database.db_engine_spec
|
||||
db_engine_spec.fetch_data.return_value = [(42,)]
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806
|
||||
|
||||
execute_sql_statement(
|
||||
sql_statement,
|
||||
query,
|
||||
cursor=cursor,
|
||||
log_params={},
|
||||
apply_ctas=False,
|
||||
)
|
||||
execute_query(query, cursor=cursor, log_params={})
|
||||
|
||||
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
|
||||
db_engine_spec.execute_with_cursor.assert_called_with(
|
||||
cursor,
|
||||
"SELECT 42 AS answer LIMIT 2",
|
||||
query,
|
||||
)
|
||||
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
||||
|
||||
|
||||
def test_execute_sql_statement_with_rls(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test for `execute_sql_statement` when an RLS rule is in place.
|
||||
"""
|
||||
from superset.sql_lab import execute_sql_statement
|
||||
|
||||
sql_statement = "SELECT * FROM sales"
|
||||
sql_statement_with_rls = f"{sql_statement} WHERE organization_id=42"
|
||||
sql_statement_with_rls_and_limit = f"{sql_statement_with_rls} LIMIT 101"
|
||||
|
||||
query = mocker.MagicMock()
|
||||
query.limit = 100
|
||||
query.select_as_cta_used = False
|
||||
database = query.database
|
||||
database.allow_dml = False
|
||||
database.apply_limit_to_sql.return_value = sql_statement_with_rls_and_limit
|
||||
database.mutate_sql_based_on_config.return_value = sql_statement_with_rls_and_limit
|
||||
db_engine_spec = database.db_engine_spec
|
||||
db_engine_spec.fetch_data.return_value = [(42,)]
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806
|
||||
mocker.patch(
|
||||
"superset.sql_lab.insert_rls_as_subquery",
|
||||
return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0],
|
||||
)
|
||||
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
|
||||
|
||||
execute_sql_statement(
|
||||
sql_statement,
|
||||
query,
|
||||
cursor=cursor,
|
||||
log_params={},
|
||||
apply_ctas=False,
|
||||
)
|
||||
|
||||
database.apply_limit_to_sql.assert_called_with(
|
||||
"SELECT * FROM sales WHERE organization_id=42",
|
||||
101,
|
||||
force=True,
|
||||
)
|
||||
db_engine_spec.execute_with_cursor.assert_called_with(
|
||||
cursor,
|
||||
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
|
||||
"SELECT 42 AS answer",
|
||||
query,
|
||||
)
|
||||
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
|
||||
@@ -232,122 +173,6 @@ def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> No
|
||||
)
|
||||
|
||||
|
||||
def test_sql_lab_insert_rls_as_subquery(
|
||||
mocker: MockerFixture,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Integration test for `insert_rls_as_subquery`.
|
||||
"""
|
||||
from flask_appbuilder.security.sqla.models import Role, User
|
||||
|
||||
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.security.manager import SupersetSecurityManager
|
||||
from superset.sql_lab import execute_sql_statement
|
||||
from superset.utils.core import RowLevelSecurityFilterType
|
||||
|
||||
engine = db.session.connection().engine
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
connection = engine.raw_connection()
|
||||
connection.execute("CREATE TABLE t (c INTEGER)")
|
||||
for i in range(10):
|
||||
connection.execute("INSERT INTO t VALUES (?)", (i,))
|
||||
|
||||
cursor = connection.cursor()
|
||||
|
||||
query = Query(
|
||||
sql="SELECT c FROM t",
|
||||
client_id="abcde",
|
||||
database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
|
||||
schema=None,
|
||||
limit=5,
|
||||
select_as_cta_used=False,
|
||||
)
|
||||
db.session.add(query)
|
||||
db.session.commit()
|
||||
|
||||
admin = User(
|
||||
first_name="Alice",
|
||||
last_name="Doe",
|
||||
email="adoe@example.org",
|
||||
username="admin",
|
||||
roles=[Role(name="Admin")],
|
||||
)
|
||||
|
||||
# first without RLS
|
||||
with override_user(admin):
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
)
|
||||
assert (
|
||||
superset_result_set.to_pandas_df().to_markdown()
|
||||
== """
|
||||
| | c |
|
||||
|---:|----:|
|
||||
| 0 | 0 |
|
||||
| 1 | 1 |
|
||||
| 2 | 2 |
|
||||
| 3 | 3 |
|
||||
| 4 | 4 |""".strip()
|
||||
)
|
||||
assert query.executed_sql == "SELECT\n c\nFROM t\nLIMIT 6"
|
||||
|
||||
# now with RLS
|
||||
rls = RowLevelSecurityFilter(
|
||||
name="sqllab_rls1",
|
||||
filter_type=RowLevelSecurityFilterType.REGULAR,
|
||||
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
|
||||
roles=[admin.roles[0]],
|
||||
group_key=None,
|
||||
clause="c > 5",
|
||||
)
|
||||
db.session.add(rls)
|
||||
db.session.flush()
|
||||
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
|
||||
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
|
||||
|
||||
with override_user(admin):
|
||||
superset_result_set = execute_sql_statement(
|
||||
sql_statement=query.sql,
|
||||
query=query,
|
||||
cursor=cursor,
|
||||
log_params=None,
|
||||
apply_ctas=False,
|
||||
)
|
||||
assert (
|
||||
superset_result_set.to_pandas_df().to_markdown()
|
||||
== """
|
||||
| | c |
|
||||
|---:|----:|
|
||||
| 0 | 6 |
|
||||
| 1 | 7 |
|
||||
| 2 | 8 |
|
||||
| 3 | 9 |""".strip()
|
||||
)
|
||||
assert (
|
||||
query.executed_sql
|
||||
== """SELECT
|
||||
c
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM t
|
||||
WHERE
|
||||
(
|
||||
t.c > 5
|
||||
)
|
||||
) AS t
|
||||
LIMIT 6"""
|
||||
)
|
||||
|
||||
|
||||
@freeze_time("2021-04-01T00:00:00Z")
|
||||
def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
"""
|
||||
@@ -377,8 +202,7 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
"OAuth2 required"
|
||||
)
|
||||
|
||||
query = mocker.MagicMock()
|
||||
query.database = database
|
||||
query = mocker.MagicMock(select_as_cta=False, database=database)
|
||||
mocker.patch("superset.sql_lab.get_query", return_value=query)
|
||||
|
||||
payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
|
||||
@@ -398,3 +222,67 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_apply_rls(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``apply_rls`` helper function.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
database.get_default_catalog.return_value = "examples"
|
||||
database.db_engine_spec = PostgresEngineSpec
|
||||
query = mocker.MagicMock(database=database, catalog="examples")
|
||||
get_predicates_for_table = mocker.patch(
|
||||
"superset.sql_lab.get_predicates_for_table",
|
||||
side_effect=[["c1 = 1"], ["c2 = 2"]],
|
||||
)
|
||||
|
||||
parsed_statement = SQLStatement("SELECT * FROM t1, t2", "postgresql")
|
||||
parsed_statement.tables = sorted(parsed_statement.tables, key=lambda x: x.table) # type: ignore
|
||||
|
||||
apply_rls(query, parsed_statement)
|
||||
|
||||
get_predicates_for_table.assert_has_calls(
|
||||
[
|
||||
mocker.call(Table("t1", "public", "examples"), database, True),
|
||||
mocker.call(Table("t2", "public", "examples"), database, True),
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
parsed_statement.format()
|
||||
== """
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM t1
|
||||
WHERE
|
||||
c1 = 1
|
||||
) AS t1, (
|
||||
SELECT
|
||||
*
|
||||
FROM t2
|
||||
WHERE
|
||||
c2 = 2
|
||||
) AS t2
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_get_predicates_for_table(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``get_predicates_for_table`` helper function.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
dataset = mocker.MagicMock()
|
||||
predicate = mocker.MagicMock()
|
||||
predicate.compile.return_value = "c1 = 1"
|
||||
dataset.get_sqla_row_level_filters.return_value = [predicate]
|
||||
db = mocker.patch("superset.sql_lab.db")
|
||||
db.session.query().filter().one_or_none.return_value = dataset
|
||||
|
||||
table = Table("t1", "public", "examples")
|
||||
assert get_predicates_for_table(table, database, True) == ["c1 = 1"]
|
||||
|
||||
Reference in New Issue
Block a user