diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 7ea5bc6a2a1..94adb59a1cf 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -65,6 +65,7 @@ from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2Redire from superset.sql.parse import ( BaseSQLStatement, LimitMethod, + RLSMethod, SQLScript, SQLStatement, Table, @@ -438,6 +439,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # the `cancel_query` value in the `extra` field of the `query` object has_query_id_before_execute = True + @classmethod + def get_rls_method(cls) -> RLSMethod: + """ + Returns the RLS method to be used for this engine. + + There are two ways to insert RLS: either replacing the table with a subquery + that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is + safer, but not supported in all databases. + """ + return ( + RLSMethod.AS_SUBQUERY + if cls.allows_subqueries and cls.allows_alias_in_select + else RLSMethod.AS_PREDICATE + ) + @classmethod def is_oauth2_enabled(cls) -> bool: return ( diff --git a/superset/models/helpers.py b/superset/models/helpers.py index db05b415848..341585e5da5 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -17,6 +17,8 @@ # pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" +from __future__ import annotations + import builtins import dataclasses import logging @@ -32,7 +34,6 @@ import numpy as np import pandas as pd import pytz import sqlalchemy as sa -import sqlparse import yaml from flask import g from flask_appbuilder import Model @@ -63,15 +64,12 @@ from superset.exceptions import ( ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, - SupersetParseError, SupersetSecurityException, ) from superset.extensions import feature_flag_manager from superset.jinja_context import BaseTemplateProcessor -from superset.sql.parse import 
SQLScript +from superset.sql.parse import SQLScript, SQLStatement from superset.sql_parse import ( - has_table_query, - insert_rls_in_predicate, sanitize_clause, ) from superset.superset_typing import ( @@ -111,9 +109,10 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] def validate_adhoc_subquery( sql: str, - database_id: int, - engine: str, + database: Database, + catalog: str | None, default_schema: str, + engine: str, ) -> str: """ Check if adhoc SQL contains sub-queries or nested sub-queries with table. @@ -125,28 +124,23 @@ def validate_adhoc_subquery( :raise SupersetSecurityException if sql contains sub-queries or nested sub-queries with table """ - statements = [] - for statement in sqlparse.parse(sql): - try: - has_table = has_table_query(str(statement), engine) - except SupersetParseError: - has_table = True + from superset.sql_lab import apply_rls - if has_table: - if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): - raise SupersetSecurityException( - SupersetError( - error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, - message=_("Custom SQL fields cannot contain sub-queries."), - level=ErrorLevel.ERROR, - ) + parsed_statement = SQLStatement(sql, engine) + if parsed_statement.has_subquery(): + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, ) - # TODO (betodealmeida): reimplement with sqlglot - statement = insert_rls_in_predicate(statement, database_id, default_schema) + ) - statements.append(statement) + # enforce RLS rules in any relevant tables + apply_rls(database, catalog, default_schema, parsed_statement) - return ";\n".join(str(statement) for statement in statements) + return parsed_statement.format() def json_to_dict(json_str: str) -> dict[Any, Any]: @@ -784,7 +778,7 @@ class ExploreMixin: # pylint: 
disable=too-many-public-methods raise NotImplementedError() @property - def database(self) -> "Database": + def database(self) -> Database: raise NotImplementedError() @property @@ -839,9 +833,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if expression: expression = validate_adhoc_subquery( expression, - database_id, - engine, + self.database, + self.catalog, schema, + engine, ) try: expression = sanitize_clause(expression) @@ -1467,6 +1462,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods extras = extras or {} time_grain = extras.get("time_grain_sqla") + # DB-specific quoting for identifiers + with self.database.get_sqla_engine() as engine: + quote = engine.dialect.identifier_preparer.quote + template_kwargs = { "columns": columns, "from_dttm": from_dttm.isoformat() if from_dttm else None, @@ -1515,6 +1514,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods columns_by_name: dict[str, "TableColumn"] = { col.column_name: col for col in self.columns } + quoted_columns_by_name = {quote(k): v for k, v in columns_by_name.items()} metrics_by_name: dict[str, "SqlMetric"] = { m.metric_name: m for m in self.metrics @@ -1636,15 +1636,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods else: selected = validate_adhoc_subquery( selected, - self.database_id, - self.database.backend, + self.database, + self.catalog, self.schema, + self.database.db_engine_spec.engine, ) outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor + col=selected, + template_processor=template_processor, ) groupby_all_columns[outer.name] = outer if ( @@ -1658,23 +1660,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods _sql = selected["sqlExpression"] _column_label = selected["label"] elif isinstance(selected, str): - _sql = selected + _sql = quote(selected) _column_label = selected 
selected = validate_adhoc_subquery( _sql, - self.database_id, - self.database.backend, + self.database, + self.catalog, self.schema, + self.database.db_engine_spec.engine, ) select_exprs.append( self.convert_tbl_column_to_sqla_col( - columns_by_name[selected], + quoted_columns_by_name[selected], template_processor=template_processor, label=_column_label, ) - if isinstance(selected, str) and selected in columns_by_name + if selected in quoted_columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label ) @@ -1989,9 +1992,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and db_engine_spec.allows_hidden_cc_in_orderby and col.name in [select_col.name for select_col in select_exprs] ): - with self.database.get_sqla_engine() as engine: - quote = engine.dialect.identifier_preparer.quote - col = literal_column(quote(col.name)) + col = literal_column(quote(col.name)) direction = sa.asc if ascending else sa.desc qry = qry.order_by(direction(col)) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 0a0f4b3e5c1..73255bef13e 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -462,6 +462,14 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def has_subquery(self) -> bool: + """ + Check if the statement has a subquery. + + :return: True if the statement has a subquery at the top level. + """ + raise NotImplementedError() + def parse_predicate(self, predicate: str) -> InternalRepresentation: """ Parse a predicate string into an AST. @@ -803,6 +811,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): return SQLStatement(ast=create_table, engine=self.engine) + def has_subquery(self) -> bool: + """ + Check if the statement has a subquery. + + :return: True if the statement has a subquery at the top level. 
+ """ + return bool(self._parsed.find(exp.Subquery)) + def parse_predicate(self, predicate: str) -> exp.Expression: """ Parse a predicate string into an AST. diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 08a694c9489..d36a530c26c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -15,15 +15,17 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=consider-using-transaction + +from __future__ import annotations + import dataclasses import logging import sys import uuid -from collections import defaultdict from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Optional, TypeVar, Union +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union import backoff import msgpack @@ -56,10 +58,9 @@ from superset.exceptions import ( SupersetResultsBackendNotConfigureException, ) from superset.extensions import celery_app, event_logger -from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql.parse import BaseSQLStatement, CTASMethod, RLSMethod, SQLScript, Table +from superset.sql.parse import BaseSQLStatement, CTASMethod, SQLScript, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer from superset.utils import json @@ -71,6 +72,9 @@ from superset.utils.core import ( from superset.utils.dates import now_as_float from superset.utils.decorators import stats_timing +if TYPE_CHECKING: + from superset.models.core import Database + config = app.config stats_logger = config["STATS_LOGGER"] SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"] @@ -197,52 +201,47 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query) -def apply_rls(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None: +def apply_rls( + database: Database, + catalog: 
str | None, + schema: str, + parsed_statement: BaseSQLStatement[Any], +) -> None: """ Modify statement inplace to ensure RLS rules are applied. """ - database = query.database - - # we need the default schema to fully qualify the table names - default_schema = database.get_default_schema_for_query(query) - # There are two ways to insert RLS: either replacing the table with a subquery # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is # safer, but not supported in all databases. - method = ( - RLSMethod.AS_SUBQUERY - if database.db_engine_spec.allows_subqueries - and database.db_engine_spec.allows_alias_in_select - else RLSMethod.AS_PREDICATE - ) + method = database.db_engine_spec.get_rls_method() # collect all RLS predicates for all tables in the query - predicates: dict[Table, list[Any]] = defaultdict(list) + predicates: dict[Table, list[Any]] = {} for table in parsed_statement.tables: # fully qualify table table = Table( table.table, - table.schema or default_schema, - table.catalog or query.catalog, + table.schema or schema, + table.catalog or catalog, ) - if table_predicates := get_predicates_for_table( - table, - database, - query.catalog == database.get_default_catalog(), - ): - predicates[table].extend( - parsed_statement.parse_predicate(predicate) - for predicate in table_predicates + predicates[table] = [ + parsed_statement.parse_predicate(predicate) + for predicate in get_predicates_for_table( + table, + database, + database.get_default_catalog(), ) + if predicate + ] - parsed_statement.apply_rls(query.catalog, default_schema, predicates, method) + parsed_statement.apply_rls(catalog, schema, predicates, method) def get_predicates_for_table( table: Table, database: Database, - is_default_catalog: bool, + default_catalog: str | None, ) -> list[str]: """ Get the RLS predicates for a table. 
@@ -254,7 +253,7 @@ def get_predicates_for_table( # if the dataset in the RLS has null catalog, match it when using the default # catalog catalog_predicate = SqlaTable.catalog == table.catalog - if table.catalog and is_default_catalog: + if table.catalog and table.catalog == default_catalog: catalog_predicate = or_( catalog_predicate, SqlaTable.catalog.is_(None), @@ -483,8 +482,9 @@ def execute_sql_statements( # noqa: C901 raise SupersetDMLNotAllowedException() if is_feature_enabled("RLS_IN_SQLLAB"): + default_schema = query.database.get_default_schema_for_query(query) for statement in parsed_script.statements: - apply_rls(query, statement) + apply_rls(query.database, query.catalog, default_schema, statement) if query.select_as_cta: # CTAS is valid when the last statement is a SELECT, while CVAS is valid when diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index fa5e7b44ba2..59364b4ed04 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -661,7 +661,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") - assert rv.status_code == 400 + assert rv.status_code == 422 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_having_parameter_closing_and_comment__400(self): @@ -709,7 +709,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") result = rv.json["result"][0]["query"] if get_example_database().backend != "presto": - assert "('boy' = 'boy')" in result + assert "(\n 'boy' = 'boy'\n )" in result @unittest.skip("Extremely flaky test on MySQL") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @@ -840,7 +840,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): unique_names = {row["name"] for row in data} self.maxDiff = None assert 
len(unique_names) == SERIES_LIMIT - assert {column for column in data[0].keys()} == {"state", "name", "sum__num"} # noqa: C416 + assert set(data[0]) == {"state", "name", "sum__num"} @pytest.mark.usefixtures( "create_annotation_layers", "load_birth_names_dashboard_with_slices" @@ -931,7 +931,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): assert rv.status_code == 200 result = rv.json["result"][0] data = result["data"] - assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} # noqa: C416 + assert set(data[0]) == {"foo", "bar", "state", "count"} # make sure results and query parameters are unescaped assert {row["foo"] for row in data} == {":foo"} assert {row["bar"] for row in data} == {":bar:"} @@ -1251,7 +1251,7 @@ class TestGetChartDataApi(BaseTestChartDataApi): response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] - assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416 + assert set(data[0]) == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] @@ -1271,7 +1271,7 @@ class TestGetChartDataApi(BaseTestChartDataApi): response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] - assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416 + assert set(data[0]) == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index c9ba88411de..21162ca52d6 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -568,6 +568,9 @@ 
def test_get_samples(test_client, login_as_admin, virtual_dataset): def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset): + if get_example_database().backend == "sqlite": + return + TableColumn( column_name="DUMMY CC", type="VARCHAR(255)", @@ -702,7 +705,7 @@ def test_get_samples_with_multiple_filters( assert "2000-01-02" in rv.json["result"]["query"] assert "2000-01-04" in rv.json["result"]["query"] assert "col3 = 1.2" in rv.json["result"]["query"] - assert "col4 is null" in rv.json["result"]["query"] + assert "col4 IS NULL" in rv.json["result"]["query"] assert "col2 = 'c'" in rv.json["result"]["query"] diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index f3d73ae5534..499526cba99 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -447,7 +447,11 @@ class TestSqlaTableModel(SupersetTestCase): return None old_inner_join = spec.allows_joins spec.allows_joins = inner_join - arbitrary_gby = "state || gender || '_test'" + arbitrary_gby = ( + "state OR gender OR '_test'" + if get_example_database().backend == "mysql" + else "state || gender || '_test'" + ) arbitrary_metric = dict( # noqa: C408 label="arbitrary", expressionType="SQL", sqlExpression="SUM(num_boys)" ) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index a56e8338352..cb5afc98aff 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -876,12 +876,6 @@ def test_special_chars_in_column_name(app_context, physical_dataset): "columns": [ "col1", "time column with spaces", - { - "label": "I_AM_A_TRUNC_COLUMN", - "sqlExpression": "time column with spaces", - "columnType": "BASE_AXIS", - "timeGrain": "P1Y", - }, ], "metrics": ["count"], "orderby": [["col1", True]], @@ -897,10 +891,8 @@ def test_special_chars_in_column_name(app_context, physical_dataset): if 
query_object.datasource.database.backend == "sqlite": # sqlite returns string as timestamp column assert df["time column with spaces"][0] == "2002-01-03 00:00:00" - assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00" else: assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03" - assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01" @only_postgresql diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 80ada22e0aa..793e2db6652 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -198,7 +198,7 @@ class TestDatabaseModel(SupersetTestCase): # assert dataset saved metric assert "count('bar_P1D')" in query # assert adhoc metric - assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query + assert "SUM(CASE WHEN user = 'user_abc' THEN 1 ELSE 0 END)" in query # Cleanup db.session.delete(table) db.session.commit() diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index d5ca4e8b100..1bb2cde5f02 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -232,7 +232,6 @@ def test_apply_rls(mocker: MockerFixture) -> None: database.get_default_schema_for_query.return_value = "public" database.get_default_catalog.return_value = "examples" database.db_engine_spec = PostgresEngineSpec - query = mocker.MagicMock(database=database, catalog="examples") get_predicates_for_table = mocker.patch( "superset.sql_lab.get_predicates_for_table", side_effect=[["c1 = 1"], ["c2 = 2"]], @@ -241,12 +240,12 @@ def test_apply_rls(mocker: MockerFixture) -> None: parsed_statement = SQLStatement("SELECT * FROM t1, t2", "postgresql") parsed_statement.tables = sorted(parsed_statement.tables, key=lambda x: x.table) # type: ignore - apply_rls(query, parsed_statement) + apply_rls(database, "examples", "public", parsed_statement) get_predicates_for_table.assert_has_calls( [ - 
mocker.call(Table("t1", "public", "examples"), database, True), - mocker.call(Table("t2", "public", "examples"), database, True), + mocker.call(Table("t1", "public", "examples"), database, "examples"), + mocker.call(Table("t2", "public", "examples"), database, "examples"), ] ) @@ -285,4 +284,4 @@ def test_get_predicates_for_table(mocker: MockerFixture) -> None: db.session.query().filter().one_or_none.return_value = dataset table = Table("t1", "public", "examples") - assert get_predicates_for_table(table, database, True) == ["c1 = 1"] + assert get_predicates_for_table(table, database, "examples") == ["c1 = 1"]