chore: Migrate warm up cache endpoint to api v1 (#23853)

This commit is contained in:
Jack Fragassi
2023-06-20 04:08:29 -07:00
committed by GitHub
parent 3e76736874
commit 5af298e1f6
14 changed files with 704 additions and 66 deletions

View File

@@ -33,6 +33,7 @@ from superset.models.dashboard import Dashboard
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
@@ -199,7 +200,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
rv = self.get_assert_metric(uri, "info")
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert set(data["permissions"]) == {"can_read", "can_write", "can_export"}
assert set(data["permissions"]) == {
"can_read",
"can_write",
"can_export",
"can_warm_up_cache",
}
def create_chart_import(self):
buf = BytesIO()
@@ -1682,3 +1688,85 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
assert data["result"][0]["slice_name"] == "name0"
assert data["result"][0]["datasource_id"] == 1
@pytest.mark.usefixtures(
"load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
)
def test_warm_up_cache(self):
self.login()
slc = self.get_slice("Girls", db.session)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
dashboard = self.get_dash_by_slug("births")
rv = self.client.put(
"/api/v1/chart/warm_up_cache",
json={"chart_id": slc.id, "dashboard_id": dashboard.id},
)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
rv = self.client.put(
"/api/v1/chart/warm_up_cache",
json={
"chart_id": slc.id,
"dashboard_id": dashboard.id,
"extra_filters": json.dumps(
[{"col": "name", "op": "in", "val": ["Jennifer"]}]
),
},
)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"],
[{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}],
)
def test_warm_up_cache_chart_id_required(self):
self.login()
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1})
self.assertEqual(rv.status_code, 400)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{"message": {"chart_id": ["Missing data for required field."]}},
)
def test_warm_up_cache_chart_not_found(self):
self.login()
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999})
self.assertEqual(rv.status_code, 404)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data, {"message": "Chart not found"})
def test_warm_up_cache_payload_validation(self):
self.login()
rv = self.client.put(
"/api/v1/chart/warm_up_cache",
json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4},
)
self.assertEqual(rv.status_code, 400)
data = json.loads(rv.data.decode("utf-8"))
print(data)
self.assertEqual(
data,
{
"message": {
"chart_id": ["Not a valid integer."],
"dashboard_id": ["Not a valid integer."],
"extra_filters": ["Not a valid string."],
}
},
)

View File

@@ -23,16 +23,24 @@ from flask import g
from superset import db, security_manager
from superset.charts.commands.create import CreateChartCommand
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.charts.commands.exceptions import (
ChartNotFoundError,
WarmUpCacheChartNotFoundError,
)
from superset.charts.commands.export import ExportChartsCommand
from superset.charts.commands.importers.v1 import ImportChartsCommand
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.commands.warm_up_cache import ChartWarmUpCacheCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_data,
load_energy_table_with_slice,
@@ -442,3 +450,23 @@ class TestChartsUpdateCommand(SupersetTestCase):
assert chart.query_context == query_context
assert len(chart.owners) == 1
assert chart.owners[0] == admin
class TestChartWarmUpCacheCommand(SupersetTestCase):
def test_warm_up_cache_command_chart_not_found(self):
with self.assertRaises(WarmUpCacheChartNotFoundError):
ChartWarmUpCacheCommand(99999, None, None).run()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self):
slc = self.get_slice("Girls", db.session)
result = ChartWarmUpCacheCommand(slc.id, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
)
# can just pass in chart as well
result = ChartWarmUpCacheCommand(slc, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
)

View File

@@ -39,6 +39,7 @@ from superset.datasets.commands.exceptions import DatasetCreateFailedError
from superset.datasets.models import Dataset
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.models.slice import Slice
from superset.utils.core import backend, get_example_default_schema
from superset.utils.database import get_example_database, get_main_database
from superset.utils.dict_import_export import export_to_dict
@@ -514,6 +515,7 @@ class TestDatasetApi(SupersetTestCase):
"can_export",
"can_duplicate",
"can_get_or_create_dataset",
"can_warm_up_cache",
}
def test_create_dataset_item(self):
@@ -2501,3 +2503,117 @@ class TestDatasetApi(SupersetTestCase):
with examples_db.get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE test_create_sqla_table_api")
db.session.commit()
@pytest.mark.usefixtures(
"load_energy_table_with_slice", "load_birth_names_dashboard_with_slices"
)
def test_warm_up_cache(self):
"""
Dataset API: Test warm up cache endpoint
"""
self.login()
energy_table = self.get_energy_usage_dataset()
energy_charts = (
db.session.query(Slice)
.filter(
Slice.datasource_id == energy_table.id, Slice.datasource_type == "table"
)
.all()
)
rv = self.client.put(
"/api/v1/dataset/warm_up_cache",
json={
"table_name": "energy_usage",
"db_name": get_example_database().database_name,
},
)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(energy_charts),
)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
assert "viz_status" in chart_result
# With dashboard id
dashboard = self.get_dash_by_slug("births")
birth_table = self.get_birth_names_dataset()
birth_charts = (
db.session.query(Slice)
.filter(
Slice.datasource_id == birth_table.id, Slice.datasource_type == "table"
)
.all()
)
rv = self.client.put(
"/api/v1/dataset/warm_up_cache",
json={
"table_name": "birth_names",
"db_name": get_example_database().database_name,
"dashboard_id": dashboard.id,
},
)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(birth_charts),
)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
assert "viz_status" in chart_result
# With extra filters
rv = self.client.put(
"/api/v1/dataset/warm_up_cache",
json={
"table_name": "birth_names",
"db_name": get_example_database().database_name,
"dashboard_id": dashboard.id,
"extra_filters": json.dumps(
[{"col": "name", "op": "in", "val": ["Jennifer"]}]
),
},
)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
len(data["result"]),
len(birth_charts),
)
for chart_result in data["result"]:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
assert "viz_status" in chart_result
def test_warm_up_cache_db_and_table_name_required(self):
self.login()
rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1})
self.assertEqual(rv.status_code, 400)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{
"message": {
"db_name": ["Missing data for required field."],
"table_name": ["Missing data for required field."],
}
},
)
def test_warm_up_cache_table_not_found(self):
self.login()
rv = self.client.put(
"/api/v1/dataset/warm_up_cache",
json={"table_name": "not_here", "db_name": "abc"},
)
self.assertEqual(rv.status_code, 404)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data,
{"message": "The provided table was not found in the provided database"},
)

View File

@@ -31,13 +31,20 @@ from superset.datasets.commands.create import CreateDatasetCommand
from superset.datasets.commands.exceptions import (
DatasetInvalidError,
DatasetNotFoundError,
WarmUpCacheTableNotFoundError,
)
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers import v0, v1
from superset.datasets.commands.warm_up_cache import DatasetWarmUpCacheCommand
from superset.models.core import Database
from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_data,
load_energy_table_with_slice,
@@ -575,3 +582,28 @@ class TestCreateDatasetCommand(SupersetTestCase):
with examples_db.get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE test_create_dataset_command")
db.session.commit()
class TestDatasetWarmUpCacheCommand(SupersetTestCase):
def test_warm_up_cache_command_table_not_found(self):
with self.assertRaises(WarmUpCacheTableNotFoundError):
DatasetWarmUpCacheCommand("not", "here", None, None).run()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self):
birth_table = self.get_birth_names_dataset()
birth_charts = (
db.session.query(Slice)
.filter(
Slice.datasource_id == birth_table.id, Slice.datasource_type == "table"
)
.all()
)
results = DatasetWarmUpCacheCommand(
get_example_database().database_name, "birth_names", None, None
).run()
self.assertEqual(len(results), len(birth_charts))
for chart_result in results:
assert "chart_id" in chart_result
assert "viz_error" in chart_result
assert "viz_status" in chart_result

View File

@@ -76,14 +76,11 @@ class TestCacheWarmUp(SupersetTestCase):
self.client.get(f"/superset/dashboard/{dash.id}/")
strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls())
expected = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
for slc in dash.slices
]
)
self.assertEqual(result, expected)
result = strategy.get_payloads()
expected = [
{"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices
]
self.assertCountEqual(result, expected)
def reset_tag(self, tag):
"""Remove associated object from tag, used to reset tests"""
@@ -95,57 +92,52 @@ class TestCacheWarmUp(SupersetTestCase):
@pytest.mark.usefixtures(
"load_unicode_dashboard_with_slice", "load_birth_names_dashboard_with_slices"
)
def test_dashboard_tags(self):
def test_dashboard_tags_strategy(self):
tag1 = get_tag("tag1", db.session, TagTypes.custom)
# delete first to make test idempotent
self.reset_tag(tag1)
strategy = DashboardTagsStrategy(["tag1"])
result = sorted(strategy.get_urls())
result = strategy.get_payloads()
expected = []
self.assertEqual(result, expected)
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
dash = self.get_dash_by_slug("births")
tag1_urls = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
for slc in dash.slices
]
)
tag1_urls = [{"chart_id": chart.id} for chart in dash.slices]
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
)
db.session.add(tagged_object)
db.session.commit()
self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
self.assertCountEqual(strategy.get_payloads(), tag1_urls)
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagTypes.custom)
self.reset_tag(tag2)
result = sorted(strategy.get_urls())
result = strategy.get_payloads()
expected = []
self.assertEqual(result, expected)
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
slc = dash.slices[0]
tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
object_id = slc.id
chart = dash.slices[0]
tag2_urls = [{"chart_id": chart.id}]
object_id = chart.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart
)
db.session.add(tagged_object)
db.session.commit()
result = sorted(strategy.get_urls())
self.assertEqual(result, tag2_urls)
result = strategy.get_payloads()
self.assertCountEqual(result, tag2_urls)
strategy = DashboardTagsStrategy(["tag1", "tag2"])
result = sorted(strategy.get_urls())
expected = sorted(tag1_urls + tag2_urls)
self.assertEqual(result, expected)
result = strategy.get_payloads()
expected = tag1_urls + tag2_urls
self.assertCountEqual(result, expected)