fix(celery cache warmup): add auth and use warm_up_cache endpoint (#21076)

This commit is contained in:

Author: ʈᵃᵢ
Date: 2022-08-30 09:24:24 -07:00
Committed by: GitHub
Parent commit: b354f2265a
Commit: 04dd8d414d
4 changed files with 70 additions and 181 deletions

View File

@@ -69,6 +69,16 @@ REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", "1")
RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab") RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")
# Flask-Caching configuration for Superset's metadata/chart cache, backed by
# the same Redis instance/DB used for SQL Lab results.
CACHE_CONFIG = {
"CACHE_TYPE": "redis",
"CACHE_DEFAULT_TIMEOUT": 300,  # seconds before a cached entry expires
"CACHE_KEY_PREFIX": "superset_",  # namespace keys to avoid collisions in Redis
"CACHE_REDIS_HOST": REDIS_HOST,
"CACHE_REDIS_PORT": REDIS_PORT,
"CACHE_REDIS_DB": REDIS_RESULTS_DB,
}
# Chart data payloads reuse the same cache backend/settings.
DATA_CACHE_CONFIG = CACHE_CONFIG
class CeleryConfig(object): class CeleryConfig(object):
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"

View File

@@ -14,73 +14,36 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import json
import logging import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from urllib import request from urllib import request
from urllib.error import URLError from urllib.error import URLError
from celery.beat import SchedulingError
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from sqlalchemy import and_, func from sqlalchemy import and_, func
from superset import app, db from superset import app, db, security_manager
from superset.extensions import celery_app from superset.extensions import celery_app
from superset.models.core import Log from superset.models.core import Log
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject from superset.models.tags import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime from superset.utils.date_parser import parse_human_datetime
from superset.views.utils import build_extra_filters from superset.utils.machine_auth import MachineAuthProvider
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def get_form_data( def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str:
chart_id: int, dashboard: Optional[Dashboard] = None
) -> Dict[str, Any]:
"""
Build `form_data` for chart GET request from dashboard's `default_filters`.
When a dashboard has `default_filters` they need to be added as extra
filters in the GET request for charts.
"""
form_data: Dict[str, Any] = {"slice_id": chart_id}
if dashboard is None or not dashboard.json_metadata:
return form_data
json_metadata = json.loads(dashboard.json_metadata)
default_filters = json.loads(json_metadata.get("default_filters", "null"))
if not default_filters:
return form_data
filter_scopes = json_metadata.get("filter_scopes", {})
layout = json.loads(dashboard.position_json or "{}")
if (
isinstance(layout, dict)
and isinstance(filter_scopes, dict)
and isinstance(default_filters, dict)
):
extra_filters = build_extra_filters(
layout, filter_scopes, default_filters, chart_id
)
if extra_filters:
form_data["extra_filters"] = extra_filters
return form_data
def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
"""Return external URL for warming up a given chart/table cache.""" """Return external URL for warming up a given chart/table cache."""
with app.test_request_context(): with app.test_request_context():
baseurl = ( baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
"{SUPERSET_WEBSERVER_PROTOCOL}://" url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}"
"{SUPERSET_WEBSERVER_ADDRESS}:" if dashboard:
"{SUPERSET_WEBSERVER_PORT}".format(**app.config) url += f"&dashboard_id={dashboard.id}"
) return url
return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
class Strategy: # pylint: disable=too-few-public-methods class Strategy: # pylint: disable=too-few-public-methods
@@ -179,8 +142,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards: for dashboard in dashboards:
for chart in dashboard.slices: for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard) urls.append(get_url(chart, dashboard))
urls.append(get_url(chart, form_data_with_filters))
return urls return urls
@@ -253,6 +215,30 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy] strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="fetch_url")
def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
    """
    Celery job that warms up a cache endpoint by fetching ``url``.

    :param url: fully-qualified warm-up URL to request
    :param headers: HTTP headers (e.g. the auth session cookie) to send
    :returns: a dict with a ``success`` key (and the response body) on a 200,
        or an ``error`` key with details otherwise
    """
    result: Dict[str, str] = {}
    try:
        logger.info("Fetching %s", url)
        req = request.Request(url, headers=headers)
        # Context manager guarantees the connection is released even if
        # reading/decoding the body raises (original leaked it behind a
        # pylint disable).
        with request.urlopen(req, timeout=600) as response:
            logger.info("Fetched %s, status code: %s", url, response.code)
            if response.code == 200:
                result = {"success": url, "response": response.read().decode("utf-8")}
            else:
                # NOTE: urlopen raises HTTPError (a URLError subclass) for most
                # non-2xx statuses, so this branch is a defensive fallback.
                result = {"error": url, "status_code": response.code}
                logger.error("Error fetching %s, status code: %s", url, response.code)
    except URLError as err:
        logger.exception("Error warming up cache!")
        result = {"error": url, "exception": str(err)}
    return result
@celery_app.task(name="cache-warmup") @celery_app.task(name="cache-warmup")
def cache_warmup( def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any strategy_name: str, *args: Any, **kwargs: Any
@@ -282,14 +268,18 @@ def cache_warmup(
logger.exception(message) logger.exception(message)
return message return message
results: Dict[str, List[str]] = {"success": [], "errors": []} user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"])
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {"Cookie": f"session={cookies.get('session', '')}"}
results: Dict[str, List[str]] = {"scheduled": [], "errors": []}
for url in strategy.get_urls(): for url in strategy.get_urls():
try: try:
logger.info("Fetching %s", url) logger.info("Scheduling %s", url)
request.urlopen(url) # pylint: disable=consider-using-with fetch_url.delay(url, headers)
results["success"].append(url) results["scheduled"].append(url)
except URLError: except SchedulingError:
logger.exception("Error warming up cache!") logger.exception("Error scheduling fetch_url: %s", url)
results["errors"].append(url) results["errors"].append(url)
return results return results

View File

@@ -38,9 +38,9 @@ from superset.models.core import Log
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import ( from superset.tasks.cache import (
DashboardTagsStrategy, DashboardTagsStrategy,
get_form_data,
TopNDashboardsStrategy, TopNDashboardsStrategy,
) )
from superset.utils.urls import get_url_host
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
from .dashboard_utils import create_dashboard, create_slice, create_table_metadata from .dashboard_utils import create_dashboard, create_slice, create_table_metadata
@@ -49,7 +49,6 @@ from .fixtures.unicode_dashboard import (
load_unicode_data, load_unicode_data,
) )
URL_PREFIX = "http://0.0.0.0:8081"
mock_positions = { mock_positions = {
"DASHBOARD_VERSION_KEY": "v2", "DASHBOARD_VERSION_KEY": "v2",
@@ -69,128 +68,6 @@ mock_positions = {
class TestCacheWarmUp(SupersetTestCase): class TestCacheWarmUp(SupersetTestCase):
def test_get_form_data_chart_only(self):
    """Without a dashboard, form_data carries only the slice id."""
    slice_id = 1
    self.assertEqual(get_form_data(slice_id, None), {"slice_id": slice_id})
def test_get_form_data_no_dashboard_metadata(self):
    """A dashboard whose json_metadata is empty contributes no filters."""
    slice_id = 1
    dash = MagicMock()
    dash.json_metadata = None
    dash.position_json = json.dumps(mock_positions)
    self.assertEqual(get_form_data(slice_id, dash), {"slice_id": slice_id})
def test_get_form_data_immune_slice(self):
    """A chart marked immune to a filter box gets no extra filters."""
    chart_id = 1
    filter_box_id = 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    # filter_scopes marks `chart_id` as immune to filter box `filter_box_id`,
    # so the box's default "name" filter must not be applied to this chart.
    dashboard.json_metadata = json.dumps(
        {
            "filter_scopes": {
                str(filter_box_id): {
                    "name": {"scope": ["ROOT_ID"], "immune": [chart_id]}
                }
            },
            "default_filters": json.dumps(
                {str(filter_box_id): {"name": ["Alice", "Bob"]}}
            ),
        }
    )
    result = get_form_data(chart_id, dashboard)
    # Only the slice id survives; no "extra_filters" key is produced.
    expected = {"slice_id": chart_id}
    self.assertEqual(result, expected)
def test_get_form_data_no_default_filters(self):
    """Metadata lacking default_filters yields the bare form_data."""
    slice_id = 1
    dash = MagicMock()
    dash.json_metadata = json.dumps({})
    dash.position_json = json.dumps(mock_positions)
    self.assertEqual(get_form_data(slice_id, dash), {"slice_id": slice_id})
def test_get_form_data_immune_fields(self):
    """Immunity is per-field: only the immune field is dropped."""
    chart_id = 1
    filter_box_id = 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    # The filter box sets both "name" and "__time_range", but the chart is
    # immune only to "__time_range" — "name" must still come through.
    dashboard.json_metadata = json.dumps(
        {
            "default_filters": json.dumps(
                {
                    str(filter_box_id): {
                        "name": ["Alice", "Bob"],
                        "__time_range": "100 years ago : today",
                    }
                }
            ),
            "filter_scopes": {
                str(filter_box_id): {
                    "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
                }
            },
        }
    )
    result = get_form_data(chart_id, dashboard)
    # Only the non-immune "name" filter is propagated.
    expected = {
        "slice_id": chart_id,
        "extra_filters": [{"col": "name", "op": "in", "val": ["Alice", "Bob"]}],
    }
    self.assertEqual(result, expected)
def test_get_form_data_no_extra_filters(self):
    """If every default filter is immune, no extra_filters key appears."""
    chart_id = 1
    filter_box_id = 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    # The only default filter is "__time_range", and the chart is immune to
    # it — so the result must contain no "extra_filters" at all.
    dashboard.json_metadata = json.dumps(
        {
            "default_filters": json.dumps(
                {str(filter_box_id): {"__time_range": "100 years ago : today"}}
            ),
            "filter_scopes": {
                str(filter_box_id): {
                    "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
                }
            },
        }
    )
    result = get_form_data(chart_id, dashboard)
    expected = {"slice_id": chart_id}
    self.assertEqual(result, expected)
def test_get_form_data(self):
    """With no filter_scopes, all default filters become extra_filters."""
    chart_id = 1
    filter_box_id = 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    # No "filter_scopes" key here, so nothing is immune and both the "name"
    # and "__time_range" defaults are applied to the chart.
    dashboard.json_metadata = json.dumps(
        {
            "default_filters": json.dumps(
                {
                    str(filter_box_id): {
                        "name": ["Alice", "Bob"],
                        "__time_range": "100 years ago : today",
                    }
                }
            )
        }
    )
    result = get_form_data(chart_id, dashboard)
    # "name" maps to an "in" filter; "__time_range" maps to an "==" filter.
    expected = {
        "slice_id": chart_id,
        "extra_filters": [
            {"col": "name", "op": "in", "val": ["Alice", "Bob"]},
            {"col": "__time_range", "op": "==", "val": "100 years ago : today"},
        ],
    }
    self.assertEqual(result, expected)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_top_n_dashboards_strategy(self): def test_top_n_dashboards_strategy(self):
# create a top visited dashboard # create a top visited dashboard
@@ -202,7 +79,12 @@ class TestCacheWarmUp(SupersetTestCase):
strategy = TopNDashboardsStrategy(1) strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls()) result = sorted(strategy.get_urls())
expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices]) expected = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
for slc in dash.slices
]
)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def reset_tag(self, tag): def reset_tag(self, tag):
@@ -228,7 +110,12 @@ class TestCacheWarmUp(SupersetTestCase):
# tag dashboard 'births' with `tag1` # tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom) tag1 = get_tag("tag1", db.session, TagTypes.custom)
dash = self.get_dash_by_slug("births") dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices]) tag1_urls = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
for slc in dash.slices
]
)
tagged_object = TaggedObject( tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
) )
@@ -248,7 +135,7 @@ class TestCacheWarmUp(SupersetTestCase):
# tag first slice # tag first slice
dash = self.get_dash_by_slug("unicode-test") dash = self.get_dash_by_slug("unicode-test")
slc = dash.slices[0] slc = dash.slices[0]
tag2_urls = [f"{URL_PREFIX}{slc.url}"] tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
object_id = slc.id object_id = slc.id
tagged_object = TaggedObject( tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart

View File

@@ -73,6 +73,8 @@ FEATURE_FLAGS = {
"DRILL_TO_DETAIL": True, "DRILL_TO_DETAIL": True,
} }
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
def GET_FEATURE_FLAGS_FUNC(ff): def GET_FEATURE_FLAGS_FUNC(ff):
ff_copy = copy(ff) ff_copy = copy(ff)