From 7c98e266ce6bd39fe8db22d8e25af711012bc9f2 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Mon, 17 Oct 2022 10:40:42 +0100
Subject: [PATCH] chore(sqla): refactor query utils (#21811)

Co-authored-by: Ville Brofeldt
---
Review notes (text in this section is not part of the commit message):

* The patch had been collapsed onto a handful of lines; one record per line
  is restored here. Leading indentation of context lines is reconstructed
  and should be checked against the target tree before applying.
* get_identifier_quoter: return annotation corrected to Callable[[str], str]
  (the function returns the dialect's quote callable, not a dict).
* get_sqla_engine_with_context: annotated as Iterator[Engine] because it is
  a @contextmanager generator; Iterator is added to the typing import.
* physical_dataset fixture: the DROP TABLE teardown opens its own engine
  context instead of reusing the variable bound by the setup block.
* login_as docstring no longer claims an admin login; short docstrings are
  added for the other new helpers.
* Hunk sizes and the diffstat are recomputed. The commit id above and the
  blob ids on the "index" lines still describe the original revision.

 superset/connectors/sqla/models.py            |  26 +++--
 superset/connectors/sqla/utils.py             |  13 +++
 superset/models/core.py                       |  19 ++--
 .../charts/data/api_tests.py                  |  71 ++++++++++++-
 tests/integration_tests/conftest.py           | 105 +++++++++++++++++-
 tests/integration_tests/sqla_models_tests.py  |   2 +-
 tests/integration_tests/test_app.py           |  21 +++-
 7 files changed, 241 insertions(+), 16 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 46975303284..dbcfb80eb37 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -83,6 +83,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampEx
 from superset.exceptions import (
     QueryClauseValidationException,
     QueryObjectValidationError,
+    SupersetSecurityException,
 )
 from superset.jinja_context import (
     BaseTemplateProcessor,
@@ -514,19 +515,19 @@ def _process_sql_expression(
     expression: Optional[str],
     database_id: int,
     schema: str,
-    template_processor: Optional[BaseTemplateProcessor],
+    template_processor: Optional[BaseTemplateProcessor] = None,
 ) -> Optional[str]:
     if template_processor and expression:
         expression = template_processor.process_template(expression)
     if expression:
-        expression = validate_adhoc_subquery(
-            expression,
-            database_id,
-            schema,
-        )
         try:
+            expression = validate_adhoc_subquery(
+                expression,
+                database_id,
+                schema,
+            )
             expression = sanitize_clause(expression)
-        except QueryClauseValidationException as ex:
+        except (QueryClauseValidationException, SupersetSecurityException) as ex:
             raise QueryObjectValidationError(ex.message) from ex
     return expression
 
@@ -1465,6 +1466,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                where = _process_sql_expression(
+                    expression=where,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
@@ -1477,7 +1483,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                having = _process_sql_expression(
+                    expression=having,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 having_clause_and += [self.text(having)]
+
         if apply_fetch_values_predicate and self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
         if granularity:
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index a2b54201d6b..5359c9e2149 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -22,6 +22,7 @@ import sqlparse
 from flask_babel import lazy_gettext as _
 from sqlalchemy import and_, inspect, or_
 from sqlalchemy.engine import Engine
+from sqlalchemy.engine.url import URL as SqlaURL
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.type_api import TypeEngine
@@ -37,6 +38,7 @@ from superset.result_set import SupersetResultSet
 from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
 from superset.superset_typing import ResultSetColumnType
 from superset.tables.models import Table as NewTable
+from superset.utils.memoized import memoized
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import SqlaTable
@@ -252,3 +254,14 @@ def load_or_create_tables(  # pylint: disable=too-many-arguments
             existing.add((table.schema, table.table))
 
     return new_tables
+
+
+@memoized
+def get_identifier_quoter(drivername: str) -> Callable[[str], str]:
+    """
+    Return the identifier-quoting function of the dialect for ``drivername``.
+
+    :param drivername: SQLAlchemy driver name used to look up the dialect
+    :returns: the ``quote`` method of that dialect's identifier preparer
+    """
+    return SqlaURL(drivername=drivername).get_dialect()().identifier_preparer.quote
diff --git a/superset/models/core.py b/superset/models/core.py
index fcc7cf16d8e..97e1d763b94 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
 import logging
 import textwrap
 from ast import literal_eval
-from contextlib import closing
+from contextlib import closing, contextmanager
 from copy import deepcopy
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type
@@ -345,6 +345,21 @@ class Database(
                 effective_username = g.user.username
         return effective_username
 
+    @contextmanager
+    def get_sqla_engine_with_context(
+        self,
+        schema: Optional[str] = None,
+        nullpool: bool = True,
+        source: Optional[utils.QuerySource] = None,
+    ) -> Iterator[Engine]:
+        """
+        Context-manager variant of ``get_sqla_engine``.
+
+        Yields the engine built by ``get_sqla_engine`` from the same arguments;
+        nothing else is set up or torn down around the ``yield``.
+        """
+        yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+
     @memoized(
         watch=(
             "impersonate_user",
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index e8f258421f9..212c9d01af4 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -21,7 +21,7 @@ import unittest
 import copy
 from datetime import datetime
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Dict, Any
 from unittest import mock
 from zipfile import ZipFile
 
@@ -974,3 +974,72 @@ class TestGetChartDataApi(BaseTestChartDataApi):
         unique_genders = {row["male_or_female"] for row in data}
         assert unique_genders == {"male", "female"}
         assert result["applied_filters"] == [{"column": "male_or_female"}]
+
+
+@pytest.fixture()
+def physical_query_context(physical_dataset) -> Dict[str, Any]:
+    return {
+        "datasource": {
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        "queries": [
+            {
+                "columns": ["col1"],
+                "metrics": ["count"],
+                "orderby": [["col1", True]],
+            }
+        ],
+        "result_type": ChartDataResultType.FULL,
+        "force": True,
+    }
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_not_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index fee13c8950a..c605819ee64 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -21,13 +21,15 @@ from typing import Any, Callable, Generator, Optional, TYPE_CHECKING
 from unittest.mock import patch
 
 import pytest
+from flask.ctx import AppContext
+from flask.testing import FlaskClient
 from sqlalchemy.engine import Engine
 
 from superset import db
 from superset.extensions import feature_flag_manager
 from superset.utils.core import json_dumps_w_dates
 from superset.utils.database import get_example_database, remove_database
-from tests.integration_tests.test_app import app
+from tests.integration_tests.test_app import app, login
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import Database
@@ -42,6 +44,31 @@ def app_context():
         yield
 
 
+@pytest.fixture
+def test_client(app_context: AppContext):
+    """Yield a Flask test client for the shared test app."""
+    with app.test_client() as client:
+        yield client
+
+
+@pytest.fixture
+def login_as(test_client: "FlaskClient[Any]"):
+    """Yield a callable that logs the given user in through ``test_client``."""
+
+    def _login_as(username: str, password: str = "general"):
+        login(test_client, username=username, password=password)
+
+    yield _login_as
+    # no need to log out as both app_context and test_client are
+    # function level fixtures anyway
+
+
+@pytest.fixture
+def login_as_admin(login_as: Callable[..., None]):
+    """Log in as ``admin`` before the test runs."""
+    yield login_as("admin")
+
+
 @pytest.fixture(autouse=True, scope="session")
 def setup_sample_data() -> Any:
     # TODO(john-bodley): Determine a cleaner way of setting up the sample data without
@@ -180,3 +207,79 @@ def with_feature_flags(**mock_feature_flags):
         return functools.update_wrapper(wrapper, test_fn)
 
     return decorate
+
+
+@pytest.fixture
+def physical_dataset():
+    """Create the ``physical_dataset`` table and dataset; remove both on teardown."""
+    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+    from superset.connectors.sqla.utils import get_identifier_quoter
+
+    example_database = get_example_database()
+
+    with example_database.get_sqla_engine_with_context() as engine:
+        quoter = get_identifier_quoter(engine.name)
+        # sqlite can only execute one statement at a time
+        engine.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS physical_dataset(
+              col1 INTEGER,
+              col2 VARCHAR(255),
+              col3 DECIMAL(4,2),
+              col4 VARCHAR(255),
+              col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+              col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+              {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
+            );
+            """
+        )
+        engine.execute(
+            """
+            INSERT INTO physical_dataset values
+            (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
+            (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
+            (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
+            (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
+            (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
+            (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
+            (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
+            (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
+            (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
+            (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
+            """
+        )
+
+    dataset = SqlaTable(
+        table_name="physical_dataset",
+        database=example_database,
+    )
+    TableColumn(column_name="col1", type="INTEGER", table=dataset)
+    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
+    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
+    TableColumn(column_name="col6", type="TIMESTAMP", is_dttm=True, table=dataset)
+    TableColumn(
+        column_name="time column with spaces",
+        type="TIMESTAMP",
+        is_dttm=True,
+        table=dataset,
+    )
+    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+    db.session.merge(dataset)
+    db.session.commit()
+
+    yield dataset
+
+    # Acquire a fresh engine for teardown rather than reusing the variable
+    # bound by the setup ``with`` block above.
+    with example_database.get_sqla_engine_with_context() as engine:
+        engine.execute(
+            """
+            DROP TABLE physical_dataset;
+            """
+        )
+    dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+    for ds in dataset:
+        db.session.delete(ds)
+    db.session.commit()
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 8990243c6b4..f06836720dc 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -262,7 +262,7 @@ class TestDatabaseModel(SupersetTestCase):
         )
         db.session.commit()
 
-        with pytest.raises(SupersetSecurityException):
+        with pytest.raises(QueryObjectValidationError):
             table.get_sqla_query(**base_query_obj)
         # Cleanup
         db.session.delete(table)
diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py
index 798f3e9cda2..fb7b47b67cb 100644
--- a/tests/integration_tests/test_app.py
+++ b/tests/integration_tests/test_app.py
@@ -14,11 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import TYPE_CHECKING
 
-"""
-Here is where we create the app which ends up being shared across all tests.integration_tests. A future
-optimization will be to create a separate app instance for each test class.
-"""
 from superset.app import create_app
 
+if TYPE_CHECKING:
+    from typing import Any
+
+    from flask.testing import FlaskClient
+
 app = create_app()
+
+
+def login(
+    client: "FlaskClient[Any]", username: str = "admin", password: str = "general"
+):
+    """POST credentials to ``/login/`` and assert no user confirmation is requested."""
+    resp = client.post(
+        "/login/",
+        data=dict(username=username, password=password),
+    ).get_data(as_text=True)
+    assert "User confirmation needed" not in resp