# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch

import pytest
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session  # noqa: F401

from superset import db
from superset.commands.explore.form_data.state import TemporaryExploreState
from superset.connectors.sqla.models import SqlaTable
from superset.explore.exceptions import DatasetAccessDeniedError
from superset.extensions import cache_manager
from superset.models.slice import Slice
from superset.utils import json
from tests.integration_tests.fixtures.world_bank_dashboard import (
    load_world_bank_dashboard_with_slices,  # noqa: F401
    load_world_bank_data,  # noqa: F401
)
from tests.integration_tests.test_app import app

FORM_DATA_KEY = "form_data_key"
FORM_DATA = {"test": "test value"}


@pytest.fixture
def chart_id(load_world_bank_dashboard_with_slices) -> int:  # noqa: F811
    """Return the id of the "World's Population" chart."""
    with app.app_context():
        chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
        return chart.id


@pytest.fixture
def admin_id() -> int:
    """Return the id of the ``admin`` user."""
    with app.app_context():
        admin = db.session.query(User).filter_by(username="admin").one()
        return admin.id


@pytest.fixture
def dataset() -> SqlaTable:
    """Return the first ``wb_health_population`` dataset row.

    The whole model object is returned (not just its id) because the tests
    read both ``dataset.id`` and ``dataset.type``.
    """
    with app.app_context():
        dataset = (
            db.session.query(SqlaTable)
            .filter_by(table_name="wb_health_population")
            .first()
        )
        return dataset


@pytest.fixture(autouse=True)
def cache(chart_id, admin_id, dataset):
    """Store ``FORM_DATA`` in the explore form-data cache under ``FORM_DATA_KEY``.

    The entry is removed again once the test finishes so that it does not
    linger in the shared cache after this module's tests have run.
    """
    entry: TemporaryExploreState = {
        "owner": admin_id,
        "datasource_id": dataset.id,
        "datasource_type": dataset.type,
        "chart_id": chart_id,
        "form_data": json.dumps(FORM_DATA),
    }
    cache_manager.explore_form_data_cache.set(FORM_DATA_KEY, entry)
    yield
    cache_manager.explore_form_data_cache.delete(FORM_DATA_KEY)


# partially match the dataset using the most important attributes
def assert_dataset(result, dataset_id):
    """Assert that ``result["dataset"]`` describes the world-bank dataset."""
    dataset = result["dataset"]
    assert dataset["id"] == dataset_id
    assert dataset["datasource_name"] == "wb_health_population"
    assert dataset["is_sqllab_view"] is False
    assert dataset["main_dttm_col"] == "year"
    assert dataset["sql"] is None
    assert dataset["type"] == "table"
    assert dataset["uid"] == f"{dataset_id}__table"


# partially match the slice using the most important attributes
def assert_slice(result, chart_id, dataset_id):
    """Assert that ``result["slice"]`` describes the "World's Population" chart."""
    # Named ``chart`` rather than ``slice`` to avoid shadowing the builtin.
    chart = result["slice"]
    assert chart["edit_url"] == f"/chart/edit/{chart_id}"
    assert chart["is_managed_externally"] is False
    assert chart["slice_id"] == chart_id
    assert chart["slice_name"] == "World's Population"
    assert chart["form_data"]["datasource"] == f"{dataset_id}__table"
    assert chart["form_data"]["viz_type"] == "big_number"


def test_no_params_provided(test_client, login_as_admin):
    """Without query parameters the API answers with a placeholder dataset."""
    resp = test_client.get("api/v1/explore/")
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert result["dataset"]["name"] == "[Missing Dataset]"
    assert result["form_data"]["datasource"] == "None__table"
    assert result["message"] is None
    assert result["slice"] is None


def test_get_from_cache(test_client, login_as_admin, dataset):
    """A known ``form_data_key`` returns the cached form data."""
    resp = test_client.get(
        f"api/v1/explore/?form_data_key={FORM_DATA_KEY}"
        f"&datasource_id={dataset.id}&datasource_type={dataset.type}"
    )
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert_dataset(result, dataset.id)
    assert result["form_data"]["datasource"] == f"{dataset.id}__table"
    assert result["form_data"]["test"] == "test value"
    assert result["message"] is None
    assert result["slice"] is None


def test_get_from_cache_unknown_key_chart_id(
    test_client, login_as_admin, chart_id, dataset
):
    """An unknown key plus ``slice_id`` falls back to the chart metadata."""
    unknown_key = "unknown_key"
    resp = test_client.get(
        f"api/v1/explore/?form_data_key={unknown_key}&slice_id={chart_id}"
    )
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert_dataset(result, dataset.id)
    assert_slice(result, chart_id, dataset.id)
    assert result["form_data"]["datasource"] == f"{dataset.id}__table"
    assert (
        result["message"]
        == "Form data not found in cache, reverting to chart metadata."
    )


def test_get_from_cache_unknown_key_dataset(test_client, login_as_admin, dataset):
    """An unknown key plus datasource params falls back to the dataset metadata."""
    unknown_key = "unknown_key"
    resp = test_client.get(
        f"api/v1/explore/?form_data_key={unknown_key}"
        f"&datasource_id={dataset.id}&datasource_type={dataset.type}"
    )
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert_dataset(result, dataset.id)
    assert result["form_data"]["datasource"] == f"{dataset.id}__table"
    assert (
        result["message"]
        == "Form data not found in cache, reverting to dataset metadata."
    )
    assert result["slice"] is None


def test_get_from_cache_unknown_key_no_extra_parameters(test_client, login_as_admin):
    """An unknown key with no other parameters yields the placeholder dataset."""
    unknown_key = "unknown_key"
    resp = test_client.get(f"api/v1/explore/?form_data_key={unknown_key}")
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert result["dataset"]["name"] == "[Missing Dataset]"
    assert result["form_data"]["datasource"] == "None__table"
    assert result["message"] is None
    assert result["slice"] is None


def test_get_from_permalink(test_client, login_as_admin, chart_id, dataset):
    """Form data stored through the permalink endpoint is returned by its key."""
    form_data = {
        "chart_id": chart_id,
        "datasource": f"{dataset.id}__{dataset.type}",
        **FORM_DATA,
    }
    resp = test_client.post("api/v1/explore/permalink", json={"formData": form_data})
    data = json.loads(resp.data.decode("utf-8"))
    permalink_key = data["key"]
    resp = test_client.get(f"api/v1/explore/?permalink_key={permalink_key}")
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")
    assert_dataset(result, dataset.id)
    assert result["form_data"]["datasource"] == f"{dataset.id}__table"
    assert result["form_data"]["test"] == "test value"
    assert result["message"] is None
    assert result["slice"] is None


def test_get_from_permalink_unknown_key(test_client, login_as_admin):
    """An unknown ``permalink_key`` is answered with 404."""
    unknown_key = "unknown_key"
    resp = test_client.get(f"api/v1/explore/?permalink_key={unknown_key}")
    assert resp.status_code == 404


@patch("superset.security.SupersetSecurityManager.can_access_datasource")
def test_get_dataset_access_denied_with_form_data_key(
    mock_can_access_datasource, test_client, login_as_admin, dataset
):
    """A denied datasource check on the cached-form-data path yields 403."""
    message = "Dataset access denied"
    mock_can_access_datasource.side_effect = DatasetAccessDeniedError(
        message=message, datasource_id=dataset.id, datasource_type=dataset.type
    )
    resp = test_client.get(
        f"api/v1/explore/?form_data_key={FORM_DATA_KEY}"
        f"&datasource_id={dataset.id}&datasource_type={dataset.type}"
    )
    data = json.loads(resp.data.decode("utf-8"))
    assert resp.status_code == 403
    assert data["datasource_id"] == dataset.id
    assert data["datasource_type"] == dataset.type
    assert data["message"] == message


@patch("superset.security.SupersetSecurityManager.raise_for_access")
def test_get_dataset_access_denied(
    mock_raise_for_access, test_client, login_as_admin, dataset
):
    """A ``DatasetAccessDeniedError`` from ``raise_for_access`` yields 403."""
    message = "Dataset access denied"
    mock_raise_for_access.side_effect = DatasetAccessDeniedError(
        message=message, datasource_id=dataset.id, datasource_type=dataset.type
    )
    resp = test_client.get(
        f"api/v1/explore/?datasource_id={dataset.id}&datasource_type={dataset.type}"
    )
    data = json.loads(resp.data.decode("utf-8"))
    assert resp.status_code == 403
    assert data["datasource_id"] == dataset.id
    assert data["datasource_type"] == dataset.type
    assert data["message"] == message


@patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset):
    """A dataset with a ``default_endpoint`` is answered with a 302 redirect."""
    dataset.default_endpoint = "another_endpoint"
    mock_get_datasource.return_value = dataset
    resp = test_client.get(
        f"api/v1/explore/?datasource_id={dataset.id}&datasource_type={dataset.type}"
    )
    data = json.loads(resp.data.decode("utf-8"))
    assert resp.status_code == 302
    assert data["redirect"] == dataset.default_endpoint


def test_get_url_params(test_client, login_as_admin, chart_id):
    """Query-string parameters are echoed back in ``form_data["url_params"]``."""
    resp = test_client.get(f"api/v1/explore/?slice_id={chart_id}&foo=bar")
    assert resp.status_code == 200
    data = json.loads(resp.data.decode("utf-8"))
    result = data.get("result")

    assert result["form_data"]["url_params"] == {
        "foo": "bar",
        "slice_id": str(chart_id),
    }


def test_granularity_sqla_override_updates_temporal_range_filter_subject(
    test_client, login_as_admin, chart_id, admin_id, dataset
):
    """
    Test that extra_form_data.granularity_sqla overrides TEMPORAL_RANGE filter
    subject.

    The flow is:
    1. Chart has TEMPORAL_RANGE adhoc filters on various columns
    2. Dashboard applies Time Column native filter selecting 'year' via
       extra_form_data
    3. Explore API processes form_data through
       merge_extra_filters/merge_extra_form_data
    4. All TEMPORAL_RANGE filter subjects should be updated to 'year'
    """
    form_data_with_temporal_filter = {
        "datasource": f"{dataset.id}__{dataset.type}",
        "viz_type": "country_map",
        "time_range": "Last week",
        "adhoc_filters": [
            {
                "clause": "WHERE",
                "comparator": "No filter",
                "expressionType": "SIMPLE",
                "operator": "TEMPORAL_RANGE",
                "subject": "some_other_time_col",
            },
            {
                "clause": "WHERE",
                "comparator": "foo",
                "expressionType": "SIMPLE",
                "operator": "==",
                "subject": "non_temporal_col",
            },
            {
                "clause": "WHERE",
                "comparator": "Last year",
                "expressionType": "SIMPLE",
                "operator": "TEMPORAL_RANGE",
                "subject": "another_time_col",
            },
        ],
        "extra_form_data": {
            "granularity_sqla": "year",
            "time_range": "Last month",
        },
    }

    # Use a dedicated cache key so the shared FORM_DATA_KEY entry is untouched.
    test_form_data_key = f"test_granularity_override_key_{chart_id}_{dataset.id}"
    entry: TemporaryExploreState = {
        "owner": admin_id,
        "datasource_id": dataset.id,
        "datasource_type": dataset.type,
        "chart_id": chart_id,
        "form_data": json.dumps(form_data_with_temporal_filter),
    }
    cache_manager.explore_form_data_cache.set(test_form_data_key, entry)

    try:
        resp = test_client.get(
            f"api/v1/explore/?form_data_key={test_form_data_key}"
            f"&datasource_id={dataset.id}&datasource_type={dataset.type}"
        )
        assert resp.status_code == 200

        data = json.loads(resp.data.decode("utf-8"))
        result = data.get("result")
        form_data = result["form_data"]
        adhoc_filters = form_data.get("adhoc_filters", [])

        temporal_range_filters = [
            f
            for f in adhoc_filters
            if f.get("operator") == "TEMPORAL_RANGE"
            and f.get("expressionType") == "SIMPLE"
        ]
        assert len(temporal_range_filters) == 2, "Expected two TEMPORAL_RANGE filters"
        for temporal_filter in temporal_range_filters:
            assert temporal_filter["subject"] == "year", (
                "Time Column native filter (granularity_sqla) should override "
                "TEMPORAL_RANGE filter subject for all matching filters"
            )

        # The non-temporal filter must come through with its subject unchanged.
        non_temporal_filters = [f for f in adhoc_filters if f.get("operator") == "=="]
        assert len(non_temporal_filters) == 1
        assert non_temporal_filters[0]["subject"] == "non_temporal_col"

        assert form_data["time_range"] == "Last month"
        assert form_data.get("granularity") == "year"
    finally:
        # Always drop the test-specific entry, even when an assertion fails.
        cache_manager.explore_form_data_cache.delete(test_form_data_key)