refactor(TestChartApi): move chart data api tests into TestChartDataApi (#17407)

* refactor charts api tests

* move new added test

* refactor charts api tests
This commit is contained in:
ofekisr
2021-11-14 23:35:23 +02:00
committed by GitHub
parent 0ca4312212
commit d8851c9a89
3 changed files with 852 additions and 864 deletions

View File

@@ -17,20 +17,10 @@
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from datetime import datetime
from io import BytesIO
from typing import Optional, List
from unittest import mock
from zipfile import is_zipfile, ZipFile
from tests.integration_tests.conftest import with_feature_flags
from superset.models.sql_lab import Query
from tests.integration_tests.insert_chart_mixin import InsertChartMixin
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
)
import humanize
import prison
import pytest
@@ -38,36 +28,23 @@ import yaml
from sqlalchemy import and_
from sqlalchemy.sql import func
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
)
from tests.integration_tests.test_app import app
from superset import security_manager
from superset.charts.commands.data import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
from superset.extensions import async_query_manager, cache_manager, db
from superset.models.annotations import AnnotationLayer
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import cache_manager, db
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils.core import (
AnnotationType,
get_example_database,
get_example_default_schema,
get_main_database,
AdhocMetricExpressionType,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.utils.core import get_example_default_schema
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import (
SupersetTestCase,
post_assert_metric,
test_client,
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.insert_chart_mixin import InsertChartMixin
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
)
from tests.integration_tests.fixtures.importexport import (
chart_config,
chart_metadata_config,
@@ -75,17 +52,13 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config,
dataset_metadata_config,
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
)
from tests.integration_tests.fixtures.query_context import (
get_query_context,
ANNOTATION_LAYERS,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_slice,
)
from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
)
from tests.integration_tests.test_app import app
from tests.integration_tests.utils.get_dashboards import get_dashboards_ids
CHART_DATA_URI = "api/v1/chart/data"
@@ -1067,641 +1040,6 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_simple(self):
"""
Chart data API: Test chart data query
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
expected_row_count = self.get_expected_row_count("client_id_1")
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_get_no_query_context(self):
"""
Chart data API: Test GET endpoint when query context is null
"""
self.login(username="admin")
chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": "Chart has no query context saved. Please save the chart again."
}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_get(self):
"""
Chart data API: Test GET endpoint
"""
self.login(username="admin")
chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
chart.query_context = json.dumps(
{
"datasource": {"id": chart.table.id, "type": "table"},
"force": False,
"queries": [
{
"time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00",
"granularity": "ds",
"filters": [],
"extras": {
"time_range_endpoints": ["inclusive", "exclusive"],
"having": "",
"having_druid": [],
"where": "",
},
"applied_time_extras": {},
"columns": ["gender"],
"metrics": ["sum__num"],
"orderby": [["sum__num", False]],
"annotation_layers": [],
"row_limit": 50000,
"timeseries_limit": 0,
"order_desc": True,
"url_params": {},
"custom_params": {},
"custom_form_data": {},
}
],
"result_format": "json",
"result_type": "full",
}
)
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["status"] == "success"
assert data["result"][0]["rowcount"] == 2
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_applied_time_extras(self):
"""
Chart data API: Test chart data query with applied time extras
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["applied_time_extras"] = {
"__time_range": "100 years ago : now",
"__time_origin": "now",
}
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"][0]["applied_filters"],
[
{"column": "gender"},
{"column": "num"},
{"column": "name"},
{"column": "__time_range"},
],
)
self.assertEqual(
data["result"][0]["rejected_filters"],
[{"column": "__time_origin", "reason": "not_druid_datasource"},],
)
expected_row_count = self.get_expected_row_count("client_id_2")
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_limit_offset(self):
"""
Chart data API: Test chart data query with limit and offset
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["row_limit"] = 5
request_payload["queries"][0]["row_offset"] = 0
request_payload["queries"][0]["orderby"] = [["name", True]]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
# TODO: fix offset for presto DB
if get_example_database().backend == "presto":
return
# ensure that offset works properly
offset = 2
expected_name = result["data"][offset]["name"]
request_payload["queries"][0]["row_offset"] = offset
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
self.assertEqual(result["data"][0]["name"], expected_name)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7},
)
def test_chart_data_default_row_limit(self):
"""
Chart data API: Ensure row count doesn't exceed default limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 7)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
)
def test_chart_data_sql_max_row_limit(self):
"""
Chart data API: Ensure row count doesn't exceed max global row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
)
def test_chart_data_sample_default_limit(self):
"""
Chart data API: Ensure sample response row count defaults to config defaults
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.SAMPLES
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.common.query_actions.config",
{**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
)
def test_chart_data_sample_custom_limit(self):
"""
Chart data API: Ensure requested sample response row count is between
default and SQL max row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
)
def test_chart_data_sql_max_row_sample_limit(self):
"""
Chart data API: Ensure requested sample response row count doesn't
exceed SQL max row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10000000
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
def test_chart_data_incorrect_result_type(self):
"""
Chart data API: Test chart data with unsupported result type
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_incorrect_result_format(self):
"""
Chart data API: Test chart data with unsupported result format
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_format"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_invalid_form_data(self):
"""
Chart data API: Test chart data with invalid form_data json
"""
self.login(username="admin")
data = {"form_data": "NOT VALID JSON"}
rv = self.client.post(
CHART_DATA_URI, data=data, content_type="multipart/form-data"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(response["message"], "Request is not JSON")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_query_result_type(self):
"""
Chart data API: Test chart data with query result format
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_csv_result_format(self):
"""
Chart data API: Test chart data with CSV result format
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
# Test chart csv without permission
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_csv_result_format_permission_denined(self):
"""
Chart data API: Test chart data with CSV result format
"""
self.login(username="gamma_no_csv")
request_payload = get_query_context("birth_names")
request_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 403)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_mixed_case_filter_op(self):
"""
Chart data API: Ensure mixed case filter operator generates valid result
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"][0]["op"] = "In"
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)
@unittest.skip("Failing due to timezone difference")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_dttm_filter(self):
"""
Chart data API: Ensure temporal column filter converts epoch to dttm expression
"""
table = self.get_birth_names_dataset()
if table.database.backend == "presto":
# TODO: date handling on Presto not fully in line with other engine specs
return
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["time_range"] = ""
dttm = self.get_dttm()
ms_epoch = dttm.timestamp() * 1000
request_payload["queries"][0]["filters"][0] = {
"col": "ds",
"op": "!=",
"val": ms_epoch,
}
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
# assert that unconverted timestamp is not present in query
assert str(ms_epoch) not in result["query"]
# assert that converted timestamp is present in query where supported
dttm_col: Optional[TableColumn] = None
for col in table.columns:
if col.column_name == table.main_dttm_col:
dttm_col = col
if dttm_col:
dttm_expression = table.database.db_engine_spec.convert_dttm(
dttm_col.type, dttm,
)
self.assertIn(dttm_expression, result["query"])
else:
raise Exception("ds column not found")
def test_chart_data_prophet(self):
"""
Chart data API: Ensure prophet post transformation works
"""
pytest.importorskip("prophet")
self.login(username="admin")
request_payload = get_query_context("birth_names")
time_grain = "P1Y"
request_payload["queries"][0]["is_timeseries"] = True
request_payload["queries"][0]["groupby"] = []
request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
request_payload["queries"][0]["granularity"] = "ds"
request_payload["queries"][0]["post_processing"] = [
{
"operation": "prophet",
"options": {
"time_grain": time_grain,
"periods": 3,
"confidence_interval": 0.9,
},
}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
row = result["data"][0]
self.assertIn("__timestamp", row)
self.assertIn("sum__num", row)
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 47)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_query_missing_filter(self):
"""
Chart data API: Ensure filter referencing missing column is ignored
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = [
{"col": "non_existent_filter", "op": "==", "val": "foo"},
]
request_payload["result_type"] = ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
assert "non_existent_filter" not in response_payload["result"][0]["query"]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_no_data(self):
"""
Chart data API: Test chart data with empty result
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "foo"}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 0)
self.assertEqual(result["data"], [])
def test_chart_data_incorrect_request(self):
"""
Chart data API: Test chart data with invalid SQL
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = []
# erroneus WHERE-clause
request_payload["queries"][0]["extras"]["where"] = "(gender abc def)"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_with_invalid_datasource(self):
"""
Chart data API: Test chart data query with invalid schema
"""
self.login(username="admin")
payload = get_query_context("birth_names")
payload["datasource"] = "abc"
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_with_invalid_enum_value(self):
"""
Chart data API: Test chart data query with invalid enum value
"""
self.login(username="admin")
payload = get_query_context("birth_names")
payload["queries"][0]["extras"]["time_range_endpoints"] = [
"abc",
"EXCLUSIVE",
]
rv = self.client.post(CHART_DATA_URI, json=payload)
self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
"""
Chart data API: Test chart data query not allowed
"""
self.login(username="gamma")
payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 401)
response_payload = json.loads(rv.data.decode("utf-8"))
assert (
response_payload["errors"][0]["error_type"]
== SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_jinja_filter_request(self):
"""
Chart data API: Ensure request referencing filters via jinja renders a correct query
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.QUERY
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "boy"}
]
request_payload["queries"][0]["extras"][
"where"
] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]["query"]
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_async(self):
"""
Chart data API: Test chart data query (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
request_payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 202)
data = json.loads(rv.data.decode("utf-8"))
keys = list(data.keys())
self.assertCountEqual(
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
)
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_async_cached_sync_response(self):
"""
Chart data API: Test chart data query returns results synchronously
when results are already cached.
"""
async_query_manager.init_app(app)
self.login(username="admin")
class QueryContext:
result_format = ChartDataResultFormat.JSON
result_type = ChartDataResultType.FULL
cmd_run_val = {
"query_context": QueryContext(),
"queries": [{"query": "select * from foo"}],
}
with mock.patch.object(
ChartDataCommand, "run", return_value=cmd_run_val
) as patched_run:
request_payload = get_query_context("birth_names")
request_payload["result_type"] = ChartDataResultType.FULL
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
patched_run.assert_called_once_with(force_cached=True)
self.assertEqual(data, {"result": [{"query": "select * from foo"}]})
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_async_results_type(self):
"""
Chart data API: Test chart data query non-JSON format (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["result_type"] = "results"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_async_invalid_token(self):
"""
Chart data API: Test chart data query (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
request_payload = get_query_context("birth_names")
test_client.set_cookie(
"localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
)
rv = test_client.post(CHART_DATA_URI, json=request_payload)
self.assertEqual(rv.status_code, 401)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
def test_chart_data_cache(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request
"""
async_query_manager.init_app(app)
self.login(username="admin")
query_context = get_query_context("birth_names")
cache_loader.load.return_value = query_context
orig_run = ChartDataCommand.run
def mock_run(self, **kwargs):
assert kwargs["force_cached"] == True
# override force_cached to get result from DB
return orig_run(self, force_cached=False)
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
data = json.loads(rv.data.decode("utf-8"))
expected_row_count = self.get_expected_row_count("client_id_3")
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_cache_run_failed(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request with run failure
"""
async_query_manager.init_app(app)
self.login(username="admin")
query_context = get_query_context("birth_names")
cache_loader.load.return_value = query_context
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(data["message"], "Error loading data from cache")
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_cache_no_login(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request (no login)
"""
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
cache_loader.load.return_value = query_context
orig_run = ChartDataCommand.run
def mock_run(self, **kwargs):
assert kwargs["force_cached"] == True
# override force_cached to get result from DB
return orig_run(self, force_cached=False)
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
rv = self.client.get(f"{CHART_DATA_URI}/test-cache-key",)
self.assertEqual(rv.status_code, 401)
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
def test_chart_data_cache_key_error(self):
"""
Chart data cache API: Test chart data async cache request with invalid cache key
"""
async_query_manager.init_app(app)
self.login(username="admin")
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
self.assertEqual(rv.status_code, 404)
def test_export_chart(self):
"""
Chart API: Test export chart
@@ -1902,191 +1240,3 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
}
]
}
@pytest.mark.usefixtures(
"create_annotation_layers", "load_birth_names_dashboard_with_slices"
)
def test_chart_data_annotations(self):
"""
Chart data API: Test chart data query
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
annotation_layers = []
request_payload["queries"][0]["annotation_layers"] = annotation_layers
# formula
annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA])
# interval
interval_layer = (
db.session.query(AnnotationLayer)
.filter(AnnotationLayer.name == "name1")
.one()
)
interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL]
interval["value"] = interval_layer.id
annotation_layers.append(interval)
# event
event_layer = (
db.session.query(AnnotationLayer)
.filter(AnnotationLayer.name == "name2")
.one()
)
event = ANNOTATION_LAYERS[AnnotationType.EVENT]
event["value"] = event_layer.id
annotation_layers.append(event)
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
# response should only contain interval and event data, not formula
self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
def get_expected_row_count(self, client_id: str) -> int:
start_date = datetime.now()
start_date = start_date.replace(
year=start_date.year - 100, hour=0, minute=0, second=0
)
quoted_table_name = self.quote_name("birth_names")
sql = f"""
SELECT COUNT(*) AS rows_count FROM (
SELECT name AS name, SUM(num) AS sum__num
FROM {quoted_table_name}
WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}'
AND gender = 'boy'
GROUP BY name
ORDER BY sum__num DESC
LIMIT 100) AS inner__query
"""
resp = self.run_sql(sql, client_id, raise_on_error=True)
db.session.query(Query).delete()
db.session.commit()
return resp["data"][0]["rows_count"]
def quote_name(self, name: str):
if get_main_database().backend in {"presto", "hive"}:
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
name
)
return name
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_rowcount(self):
"""
Chart data API: Query total rows
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["is_rowcount"] = True
request_payload["queries"][0]["groupby"] = ["name"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
expected_row_count = self.get_expected_row_count("client_id_4")
self.assertEqual(result["data"][0]["rowcount"], expected_row_count)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_timegrains(self):
"""
Chart data API: Query timegrains and columns
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"] = [
{"result_type": ChartDataResultType.TIMEGRAINS},
{"result_type": ChartDataResultType.COLUMNS},
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
timegrain_result = response_payload["result"][0]
column_result = response_payload["result"][1]
assert list(timegrain_result["data"][0].keys()) == [
"name",
"function",
"duration",
]
assert list(column_result["data"][0].keys()) == [
"column_name",
"verbose_name",
"dtype",
]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_series_limit(self):
"""
Chart data API: Query total rows
"""
SERIES_LIMIT = 5
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["columns"] = ["state", "name"]
request_payload["queries"][0]["series_columns"] = ["name"]
request_payload["queries"][0]["series_limit"] = SERIES_LIMIT
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
data = response_payload["result"][0]["data"]
unique_names = set(row["name"] for row in data)
self.maxDiff = None
self.assertEqual(len(unique_names), SERIES_LIMIT)
self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
"""
Chart data API: test query with literal colon characters in query, metrics,
where clause and filters
"""
self.login(username="admin")
owner = self.get_user("admin").id
user = db.session.query(security_manager.user_model).get(owner)
table = SqlaTable(
table_name="virtual_table_1",
schema=get_example_default_schema(),
owners=[user],
database=get_example_database(),
sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
)
db.session.add(table)
db.session.commit()
table.fetch_metadata()
request_payload = get_query_context("birth_names")
request_payload["datasource"] = {
"type": "table",
"id": table.id,
}
request_payload["queries"][0]["columns"] = ["foo", "bar", "state"]
request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'"
request_payload["queries"][0]["orderby"] = None
request_payload["queries"][0]["metrics"] = [
{
"expressionType": AdhocMetricExpressionType.SQL,
"sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
"label": "count",
}
]
request_payload["queries"][0]["filters"] = [
{"col": "foo", "op": "!=", "val": ":qwerty:",}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
db.session.delete(table)
db.session.commit()
assert rv.status_code == 200
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
data = result["data"]
assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"}
# make sure results and query parameters are unescaped
assert {row["foo"] for row in data} == {":foo"}
assert {row["bar"] for row in data} == {":bar:"}
assert "':asdf'" in result["query"]
assert "':xyz:qwerty'" in result["query"]
assert "':qwerty:'" in result["query"]