mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
refactor(TestChartApi): move chart data api tests into TestChartDataApi (#17407)
* refactor charts api tests * move new added test * refactor charts api tests
This commit is contained in:
@@ -17,20 +17,10 @@
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Optional, List
|
||||
from unittest import mock
|
||||
from zipfile import is_zipfile, ZipFile
|
||||
|
||||
from tests.integration_tests.conftest import with_feature_flags
|
||||
from superset.models.sql_lab import Query
|
||||
from tests.integration_tests.insert_chart_mixin import InsertChartMixin
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
)
|
||||
|
||||
import humanize
|
||||
import prison
|
||||
import pytest
|
||||
@@ -38,36 +28,23 @@ import yaml
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from tests.integration_tests.fixtures.world_bank_dashboard import (
|
||||
load_world_bank_dashboard_with_slices,
|
||||
)
|
||||
from tests.integration_tests.test_app import app
|
||||
from superset import security_manager
|
||||
from superset.charts.commands.data import ChartDataCommand
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.extensions import async_query_manager, cache_manager, db
|
||||
from superset.models.annotations import AnnotationLayer
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.extensions import cache_manager, db
|
||||
from superset.models.core import Database, FavStar, FavStarClassName
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.reports import ReportSchedule, ReportScheduleType
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import (
|
||||
AnnotationType,
|
||||
get_example_database,
|
||||
get_example_default_schema,
|
||||
get_main_database,
|
||||
AdhocMetricExpressionType,
|
||||
)
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.utils.core import get_example_default_schema
|
||||
|
||||
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
|
||||
from tests.integration_tests.base_tests import (
|
||||
SupersetTestCase,
|
||||
post_assert_metric,
|
||||
test_client,
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.insert_chart_mixin import InsertChartMixin
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
)
|
||||
from tests.integration_tests.fixtures.energy_dashboard import (
|
||||
load_energy_table_with_slice,
|
||||
)
|
||||
|
||||
from tests.integration_tests.fixtures.importexport import (
|
||||
chart_config,
|
||||
chart_metadata_config,
|
||||
@@ -75,17 +52,13 @@ from tests.integration_tests.fixtures.importexport import (
|
||||
dataset_config,
|
||||
dataset_metadata_config,
|
||||
)
|
||||
from tests.integration_tests.fixtures.energy_dashboard import (
|
||||
load_energy_table_with_slice,
|
||||
)
|
||||
from tests.integration_tests.fixtures.query_context import (
|
||||
get_query_context,
|
||||
ANNOTATION_LAYERS,
|
||||
)
|
||||
from tests.integration_tests.fixtures.unicode_dashboard import (
|
||||
load_unicode_dashboard_with_slice,
|
||||
)
|
||||
from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers
|
||||
from tests.integration_tests.fixtures.world_bank_dashboard import (
|
||||
load_world_bank_dashboard_with_slices,
|
||||
)
|
||||
from tests.integration_tests.test_app import app
|
||||
from tests.integration_tests.utils.get_dashboards import get_dashboards_ids
|
||||
|
||||
CHART_DATA_URI = "api/v1/chart/data"
|
||||
@@ -1067,641 +1040,6 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(data["count"], 0)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_simple(self):
|
||||
"""
|
||||
Chart data API: Test chart data query
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
expected_row_count = self.get_expected_row_count("client_id_1")
|
||||
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_get_no_query_context(self):
|
||||
"""
|
||||
Chart data API: Test GET endpoint when query context is null
|
||||
"""
|
||||
self.login(username="admin")
|
||||
chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
|
||||
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data == {
|
||||
"message": "Chart has no query context saved. Please save the chart again."
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_get(self):
|
||||
"""
|
||||
Chart data API: Test GET endpoint
|
||||
"""
|
||||
self.login(username="admin")
|
||||
chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
|
||||
chart.query_context = json.dumps(
|
||||
{
|
||||
"datasource": {"id": chart.table.id, "type": "table"},
|
||||
"force": False,
|
||||
"queries": [
|
||||
{
|
||||
"time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00",
|
||||
"granularity": "ds",
|
||||
"filters": [],
|
||||
"extras": {
|
||||
"time_range_endpoints": ["inclusive", "exclusive"],
|
||||
"having": "",
|
||||
"having_druid": [],
|
||||
"where": "",
|
||||
},
|
||||
"applied_time_extras": {},
|
||||
"columns": ["gender"],
|
||||
"metrics": ["sum__num"],
|
||||
"orderby": [["sum__num", False]],
|
||||
"annotation_layers": [],
|
||||
"row_limit": 50000,
|
||||
"timeseries_limit": 0,
|
||||
"order_desc": True,
|
||||
"url_params": {},
|
||||
"custom_params": {},
|
||||
"custom_form_data": {},
|
||||
}
|
||||
],
|
||||
"result_format": "json",
|
||||
"result_type": "full",
|
||||
}
|
||||
)
|
||||
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["result"][0]["status"] == "success"
|
||||
assert data["result"][0]["rowcount"] == 2
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_applied_time_extras(self):
|
||||
"""
|
||||
Chart data API: Test chart data query with applied time extras
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["applied_time_extras"] = {
|
||||
"__time_range": "100 years ago : now",
|
||||
"__time_origin": "now",
|
||||
}
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(
|
||||
data["result"][0]["applied_filters"],
|
||||
[
|
||||
{"column": "gender"},
|
||||
{"column": "num"},
|
||||
{"column": "name"},
|
||||
{"column": "__time_range"},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
data["result"][0]["rejected_filters"],
|
||||
[{"column": "__time_origin", "reason": "not_druid_datasource"},],
|
||||
)
|
||||
expected_row_count = self.get_expected_row_count("client_id_2")
|
||||
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_limit_offset(self):
|
||||
"""
|
||||
Chart data API: Test chart data query with limit and offset
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["row_limit"] = 5
|
||||
request_payload["queries"][0]["row_offset"] = 0
|
||||
request_payload["queries"][0]["orderby"] = [["name", True]]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
|
||||
# TODO: fix offset for presto DB
|
||||
if get_example_database().backend == "presto":
|
||||
return
|
||||
|
||||
# ensure that offset works properly
|
||||
offset = 2
|
||||
expected_name = result["data"][offset]["name"]
|
||||
request_payload["queries"][0]["row_offset"] = offset
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
self.assertEqual(result["data"][0]["name"], expected_name)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7},
|
||||
)
|
||||
def test_chart_data_default_row_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure row count doesn't exceed default limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
del request_payload["queries"][0]["row_limit"]
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 7)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
|
||||
)
|
||||
def test_chart_data_sql_max_row_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure row count doesn't exceed max global row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["row_limit"] = 10000000
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
|
||||
)
|
||||
def test_chart_data_sample_default_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure sample response row count defaults to config defaults
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.SAMPLES
|
||||
del request_payload["queries"][0]["row_limit"]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_actions.config",
|
||||
{**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
|
||||
)
|
||||
def test_chart_data_sample_custom_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure requested sample response row count is between
|
||||
default and SQL max row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.SAMPLES
|
||||
request_payload["queries"][0]["row_limit"] = 10
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
|
||||
)
|
||||
def test_chart_data_sql_max_row_sample_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure requested sample response row count doesn't
|
||||
exceed SQL max row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.SAMPLES
|
||||
request_payload["queries"][0]["row_limit"] = 10000000
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
|
||||
def test_chart_data_incorrect_result_type(self):
|
||||
"""
|
||||
Chart data API: Test chart data with unsupported result type
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = "qwerty"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_incorrect_result_format(self):
|
||||
"""
|
||||
Chart data API: Test chart data with unsupported result format
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_format"] = "qwerty"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_invalid_form_data(self):
|
||||
"""
|
||||
Chart data API: Test chart data with invalid form_data json
|
||||
"""
|
||||
self.login(username="admin")
|
||||
data = {"form_data": "NOT VALID JSON"}
|
||||
|
||||
rv = self.client.post(
|
||||
CHART_DATA_URI, data=data, content_type="multipart/form-data"
|
||||
)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response["message"], "Request is not JSON")
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_query_result_type(self):
|
||||
"""
|
||||
Chart data API: Test chart data with query result format
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.QUERY
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_csv_result_format(self):
|
||||
"""
|
||||
Chart data API: Test chart data with CSV result format
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_format"] = "csv"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
# Test chart csv without permission
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_csv_result_format_permission_denined(self):
|
||||
"""
|
||||
Chart data API: Test chart data with CSV result format
|
||||
"""
|
||||
self.login(username="gamma_no_csv")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_format"] = "csv"
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 403)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_mixed_case_filter_op(self):
|
||||
"""
|
||||
Chart data API: Ensure mixed case filter operator generates valid result
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["filters"][0]["op"] = "In"
|
||||
request_payload["queries"][0]["row_limit"] = 10
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@unittest.skip("Failing due to timezone difference")
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_dttm_filter(self):
|
||||
"""
|
||||
Chart data API: Ensure temporal column filter converts epoch to dttm expression
|
||||
"""
|
||||
table = self.get_birth_names_dataset()
|
||||
if table.database.backend == "presto":
|
||||
# TODO: date handling on Presto not fully in line with other engine specs
|
||||
return
|
||||
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["time_range"] = ""
|
||||
dttm = self.get_dttm()
|
||||
ms_epoch = dttm.timestamp() * 1000
|
||||
request_payload["queries"][0]["filters"][0] = {
|
||||
"col": "ds",
|
||||
"op": "!=",
|
||||
"val": ms_epoch,
|
||||
}
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
|
||||
# assert that unconverted timestamp is not present in query
|
||||
assert str(ms_epoch) not in result["query"]
|
||||
|
||||
# assert that converted timestamp is present in query where supported
|
||||
dttm_col: Optional[TableColumn] = None
|
||||
for col in table.columns:
|
||||
if col.column_name == table.main_dttm_col:
|
||||
dttm_col = col
|
||||
if dttm_col:
|
||||
dttm_expression = table.database.db_engine_spec.convert_dttm(
|
||||
dttm_col.type, dttm,
|
||||
)
|
||||
self.assertIn(dttm_expression, result["query"])
|
||||
else:
|
||||
raise Exception("ds column not found")
|
||||
|
||||
def test_chart_data_prophet(self):
|
||||
"""
|
||||
Chart data API: Ensure prophet post transformation works
|
||||
"""
|
||||
pytest.importorskip("prophet")
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
time_grain = "P1Y"
|
||||
request_payload["queries"][0]["is_timeseries"] = True
|
||||
request_payload["queries"][0]["groupby"] = []
|
||||
request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
|
||||
request_payload["queries"][0]["granularity"] = "ds"
|
||||
request_payload["queries"][0]["post_processing"] = [
|
||||
{
|
||||
"operation": "prophet",
|
||||
"options": {
|
||||
"time_grain": time_grain,
|
||||
"periods": 3,
|
||||
"confidence_interval": 0.9,
|
||||
},
|
||||
}
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
row = result["data"][0]
|
||||
self.assertIn("__timestamp", row)
|
||||
self.assertIn("sum__num", row)
|
||||
self.assertIn("sum__num__yhat", row)
|
||||
self.assertIn("sum__num__yhat_upper", row)
|
||||
self.assertIn("sum__num__yhat_lower", row)
|
||||
self.assertEqual(result["rowcount"], 47)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_query_missing_filter(self):
|
||||
"""
|
||||
Chart data API: Ensure filter referencing missing column is ignored
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["filters"] = [
|
||||
{"col": "non_existent_filter", "op": "==", "val": "foo"},
|
||||
]
|
||||
request_payload["result_type"] = ChartDataResultType.QUERY
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
assert "non_existent_filter" not in response_payload["result"][0]["query"]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_no_data(self):
|
||||
"""
|
||||
Chart data API: Test chart data with empty result
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["filters"] = [
|
||||
{"col": "gender", "op": "==", "val": "foo"}
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 0)
|
||||
self.assertEqual(result["data"], [])
|
||||
|
||||
def test_chart_data_incorrect_request(self):
|
||||
"""
|
||||
Chart data API: Test chart data with invalid SQL
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["filters"] = []
|
||||
# erroneus WHERE-clause
|
||||
request_payload["queries"][0]["extras"]["where"] = "(gender abc def)"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
def test_chart_data_with_invalid_datasource(self):
|
||||
"""
|
||||
Chart data API: Test chart data query with invalid schema
|
||||
"""
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
payload["datasource"] = "abc"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
def test_chart_data_with_invalid_enum_value(self):
|
||||
"""
|
||||
Chart data API: Test chart data query with invalid enum value
|
||||
"""
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
payload["queries"][0]["extras"]["time_range_endpoints"] = [
|
||||
"abc",
|
||||
"EXCLUSIVE",
|
||||
]
|
||||
rv = self.client.post(CHART_DATA_URI, json=payload)
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
|
||||
def test_query_exec_not_allowed(self):
|
||||
"""
|
||||
Chart data API: Test chart data query not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
payload = get_query_context("birth_names")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
assert (
|
||||
response_payload["errors"][0]["error_type"]
|
||||
== SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_jinja_filter_request(self):
|
||||
"""
|
||||
Chart data API: Ensure request referencing filters via jinja renders a correct query
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.QUERY
|
||||
request_payload["queries"][0]["filters"] = [
|
||||
{"col": "gender", "op": "==", "val": "boy"}
|
||||
]
|
||||
request_payload["queries"][0]["extras"][
|
||||
"where"
|
||||
] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]["query"]
|
||||
if get_example_database().backend != "presto":
|
||||
assert "('boy' = 'boy')" in result
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_async(self):
|
||||
"""
|
||||
Chart data API: Test chart data query (async)
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
keys = list(data.keys())
|
||||
self.assertCountEqual(
|
||||
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
|
||||
)
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_async_cached_sync_response(self):
|
||||
"""
|
||||
Chart data API: Test chart data query returns results synchronously
|
||||
when results are already cached.
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
|
||||
class QueryContext:
|
||||
result_format = ChartDataResultFormat.JSON
|
||||
result_type = ChartDataResultType.FULL
|
||||
|
||||
cmd_run_val = {
|
||||
"query_context": QueryContext(),
|
||||
"queries": [{"query": "select * from foo"}],
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
ChartDataCommand, "run", return_value=cmd_run_val
|
||||
) as patched_run:
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = ChartDataResultType.FULL
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
patched_run.assert_called_once_with(force_cached=True)
|
||||
self.assertEqual(data, {"result": [{"query": "select * from foo"}]})
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_async_results_type(self):
|
||||
"""
|
||||
Chart data API: Test chart data query non-JSON format (async)
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = "results"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_async_invalid_token(self):
|
||||
"""
|
||||
Chart data API: Test chart data query (async)
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
test_client.set_cookie(
|
||||
"localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
|
||||
)
|
||||
rv = test_client.post(CHART_DATA_URI, json=request_payload)
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
|
||||
def test_chart_data_cache(self, cache_loader):
|
||||
"""
|
||||
Chart data cache API: Test chart data async cache request
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
query_context = get_query_context("birth_names")
|
||||
cache_loader.load.return_value = query_context
|
||||
orig_run = ChartDataCommand.run
|
||||
|
||||
def mock_run(self, **kwargs):
|
||||
assert kwargs["force_cached"] == True
|
||||
# override force_cached to get result from DB
|
||||
return orig_run(self, force_cached=False)
|
||||
|
||||
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
|
||||
rv = self.get_assert_metric(
|
||||
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
|
||||
)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
expected_row_count = self.get_expected_row_count("client_id_3")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_cache_run_failed(self, cache_loader):
|
||||
"""
|
||||
Chart data cache API: Test chart data async cache request with run failure
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
query_context = get_query_context("birth_names")
|
||||
cache_loader.load.return_value = query_context
|
||||
rv = self.get_assert_metric(
|
||||
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
|
||||
)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(data["message"], "Error loading data from cache")
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_cache_no_login(self, cache_loader):
|
||||
"""
|
||||
Chart data cache API: Test chart data async cache request (no login)
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
query_context = get_query_context("birth_names")
|
||||
cache_loader.load.return_value = query_context
|
||||
orig_run = ChartDataCommand.run
|
||||
|
||||
def mock_run(self, **kwargs):
|
||||
assert kwargs["force_cached"] == True
|
||||
# override force_cached to get result from DB
|
||||
return orig_run(self, force_cached=False)
|
||||
|
||||
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
|
||||
rv = self.client.get(f"{CHART_DATA_URI}/test-cache-key",)
|
||||
|
||||
self.assertEqual(rv.status_code, 401)
|
||||
|
||||
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
|
||||
def test_chart_data_cache_key_error(self):
|
||||
"""
|
||||
Chart data cache API: Test chart data async cache request with invalid cache key
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
rv = self.get_assert_metric(
|
||||
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
|
||||
)
|
||||
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_export_chart(self):
|
||||
"""
|
||||
Chart API: Test export chart
|
||||
@@ -1902,191 +1240,3 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"create_annotation_layers", "load_birth_names_dashboard_with_slices"
|
||||
)
|
||||
def test_chart_data_annotations(self):
|
||||
"""
|
||||
Chart data API: Test chart data query
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
|
||||
annotation_layers = []
|
||||
request_payload["queries"][0]["annotation_layers"] = annotation_layers
|
||||
|
||||
# formula
|
||||
annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA])
|
||||
|
||||
# interval
|
||||
interval_layer = (
|
||||
db.session.query(AnnotationLayer)
|
||||
.filter(AnnotationLayer.name == "name1")
|
||||
.one()
|
||||
)
|
||||
interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL]
|
||||
interval["value"] = interval_layer.id
|
||||
annotation_layers.append(interval)
|
||||
|
||||
# event
|
||||
event_layer = (
|
||||
db.session.query(AnnotationLayer)
|
||||
.filter(AnnotationLayer.name == "name2")
|
||||
.one()
|
||||
)
|
||||
event = ANNOTATION_LAYERS[AnnotationType.EVENT]
|
||||
event["value"] = event_layer.id
|
||||
annotation_layers.append(event)
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
# response should only contain interval and event data, not formula
|
||||
self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
|
||||
|
||||
def get_expected_row_count(self, client_id: str) -> int:
|
||||
start_date = datetime.now()
|
||||
start_date = start_date.replace(
|
||||
year=start_date.year - 100, hour=0, minute=0, second=0
|
||||
)
|
||||
|
||||
quoted_table_name = self.quote_name("birth_names")
|
||||
sql = f"""
|
||||
SELECT COUNT(*) AS rows_count FROM (
|
||||
SELECT name AS name, SUM(num) AS sum__num
|
||||
FROM {quoted_table_name}
|
||||
WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}'
|
||||
AND gender = 'boy'
|
||||
GROUP BY name
|
||||
ORDER BY sum__num DESC
|
||||
LIMIT 100) AS inner__query
|
||||
"""
|
||||
resp = self.run_sql(sql, client_id, raise_on_error=True)
|
||||
db.session.query(Query).delete()
|
||||
db.session.commit()
|
||||
return resp["data"][0]["rows_count"]
|
||||
|
||||
def quote_name(self, name: str):
|
||||
if get_main_database().backend in {"presto", "hive"}:
|
||||
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
|
||||
name
|
||||
)
|
||||
return name
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_rowcount(self):
|
||||
"""
|
||||
Chart data API: Query total rows
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["is_rowcount"] = True
|
||||
request_payload["queries"][0]["groupby"] = ["name"]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
expected_row_count = self.get_expected_row_count("client_id_4")
|
||||
self.assertEqual(result["data"][0]["rowcount"], expected_row_count)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_timegrains(self):
|
||||
"""
|
||||
Chart data API: Query timegrains and columns
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"] = [
|
||||
{"result_type": ChartDataResultType.TIMEGRAINS},
|
||||
{"result_type": ChartDataResultType.COLUMNS},
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
timegrain_result = response_payload["result"][0]
|
||||
column_result = response_payload["result"][1]
|
||||
assert list(timegrain_result["data"][0].keys()) == [
|
||||
"name",
|
||||
"function",
|
||||
"duration",
|
||||
]
|
||||
assert list(column_result["data"][0].keys()) == [
|
||||
"column_name",
|
||||
"verbose_name",
|
||||
"dtype",
|
||||
]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_series_limit(self):
|
||||
"""
|
||||
Chart data API: Query total rows
|
||||
"""
|
||||
SERIES_LIMIT = 5
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["columns"] = ["state", "name"]
|
||||
request_payload["queries"][0]["series_columns"] = ["name"]
|
||||
request_payload["queries"][0]["series_limit"] = SERIES_LIMIT
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
data = response_payload["result"][0]["data"]
|
||||
unique_names = set(row["name"] for row in data)
|
||||
self.maxDiff = None
|
||||
self.assertEqual(len(unique_names), SERIES_LIMIT)
|
||||
self.assertEqual(
|
||||
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_chart_data_virtual_table_with_colons(self):
|
||||
"""
|
||||
Chart data API: test query with literal colon characters in query, metrics,
|
||||
where clause and filters
|
||||
"""
|
||||
self.login(username="admin")
|
||||
owner = self.get_user("admin").id
|
||||
user = db.session.query(security_manager.user_model).get(owner)
|
||||
|
||||
table = SqlaTable(
|
||||
table_name="virtual_table_1",
|
||||
schema=get_example_default_schema(),
|
||||
owners=[user],
|
||||
database=get_example_database(),
|
||||
sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
|
||||
)
|
||||
db.session.add(table)
|
||||
db.session.commit()
|
||||
table.fetch_metadata()
|
||||
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["datasource"] = {
|
||||
"type": "table",
|
||||
"id": table.id,
|
||||
}
|
||||
request_payload["queries"][0]["columns"] = ["foo", "bar", "state"]
|
||||
request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'"
|
||||
request_payload["queries"][0]["orderby"] = None
|
||||
request_payload["queries"][0]["metrics"] = [
|
||||
{
|
||||
"expressionType": AdhocMetricExpressionType.SQL,
|
||||
"sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
|
||||
"label": "count",
|
||||
}
|
||||
]
|
||||
request_payload["queries"][0]["filters"] = [
|
||||
{"col": "foo", "op": "!=", "val": ":qwerty:",}
|
||||
]
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
assert rv.status_code == 200
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
data = result["data"]
|
||||
assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"}
|
||||
# make sure results and query parameters are unescaped
|
||||
assert {row["foo"] for row in data} == {":foo"}
|
||||
assert {row["bar"] for row in data} == {":bar:"}
|
||||
assert "':asdf'" in result["query"]
|
||||
assert "':xyz:qwerty'" in result["query"]
|
||||
assert "':qwerty:'" in result["query"]
|
||||
|
||||
Reference in New Issue
Block a user