Compare commits

...

4 Commits

Author SHA1 Message Date
Evan Rusackas
31b4e5f0dd docs: fix Celery beat configuration example
Update the cache warmup documentation to use the correct CeleryConfig
class pattern that's consistent with how Superset handles Celery
configuration.

- Use CeleryConfig.beat_schedule instead of CELERYBEAT_SCHEDULE
- Show how to extend CeleryConfig to preserve existing schedules
- Add missing import for CeleryConfig

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-09 15:43:20 -08:00
Evan Rusackas
e0448ef3a4 test: remove obsolete tests for HTTP-based cache warmup
The cache warmup mechanism has been refactored to use WebDriver instead
of direct HTTP requests. These tests were testing the old fetch_url and
fetch_csrf_token functions which no longer exist in cache.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 21:51:08 -08:00
Evan Rusackas
41d89de139 fix: address bot review feedback and fix failing tests
Addresses bot review comments:
- Add public create() and destroy() methods for WebDriverSelenium
  (fixes tests calling removed public method)
- Fix user session handling - re-authenticate if different user passed
- Change exception handling from URLError to WebDriverException
- Use wd.destroy() for proper WebDriver cleanup instead of del
- Fix N+1 query in DummyStrategy with selectinload for dashboard.slices
- Fix docs: config.py -> superset_config.py and add crontab import
- Fix tests: use get_url_host() instead of hardcoded localhost

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 21:49:31 -08:00
Evan Rusackas
d715176847 fix: cache warmup using WebDriver for reliable authentication
This PR adopts and improves PR #20387 by @ensky to fix cache warmup issues
where the task was unable to authenticate properly.

Changes:
- Use WebDriver (Selenium) to render dashboards for cache warmup instead of
  API calls, ensuring proper authentication and accurate cache population
- Add SUPERSET_CACHE_WARMUP_USER config for specifying the warmup user
- Refine WebDriverSelenium to support persistent driver instances,
  avoiding driver recreation for each URL
- Warm up entire dashboards instead of individual charts, since dashboard
  context affects how charts are cached
- Add documentation for Celery beat configuration

The WebDriver approach simulates real user behavior, ensuring caches are
populated exactly as users would experience them.

Fixes #9597, #18933
Originally by @ensky in PR #20387

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 21:49:31 -08:00
9 changed files with 209 additions and 373 deletions

View File

@@ -80,6 +80,39 @@ instead requires a cachelib object.
See [Async Queries via Celery](/docs/configuration/async-queries-celery) for details.
## Celery beat
Superset has a Celery task that will periodically warm up the cache based on different strategies.
To use it, add the following to your `superset_config.py`:
```python
from celery.schedules import crontab
from superset.config import CeleryConfig
# User that will be used to authenticate and render dashboards for cache warmup
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"
# Extend the default CeleryConfig to add cache warmup schedule
class CustomCeleryConfig(CeleryConfig):
    # Start from the schedules already defined on CeleryConfig so they are
    # preserved, then add the cache warmup entry on top.
    beat_schedule = {
        **CeleryConfig.beat_schedule,
        'cache-warmup-hourly': {
            # Name under which the cache warmup Celery task is registered.
            'task': 'cache-warmup',
            'schedule': crontab(minute=0, hour='*'), # hourly
            # Keyword arguments passed to the task on every run.
            'kwargs': {
                'strategy_name': 'top_n_dashboards',
                'top_n': 5,
                'since': '7 days ago',
            },
        },
    }
CELERY_CONFIG = CustomCeleryConfig
```
This will cache the top 5 most popular dashboards every hour. For other
strategies, check the `superset/tasks/cache.py` file.
## Caching Thumbnails
This is an optional feature that can be turned on by activating its [feature flag](/docs/configuration/configuring-superset#feature-flags) on config:

View File

@@ -1047,6 +1047,9 @@ THUMBNAIL_CACHE_CONFIG: CacheConfig = {
}
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
# Cache warmup user
SUPERSET_CACHE_WARMUP_USER = "admin"
# Time selenium waits to locate an element on the page, and for that element to load
# for a screenshot, before timing out.
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())

View File

@@ -17,65 +17,46 @@
from __future__ import annotations
import logging
from typing import Any, Optional, TypedDict, Union
from urllib import request
from urllib.error import URLError
from typing import Any, Optional, Union
from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
from flask import current_app
from selenium.common.exceptions import WebDriverException
from sqlalchemy import and_, func
from sqlalchemy.orm import selectinload
from superset import db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.utils import fetch_csrf_token, get_executor
from superset.utils import json
from superset.utils.date_parser import parse_human_datetime
from superset.utils.machine_auth import MachineAuthProvider
from superset.utils.urls import get_url_path, is_secure_url
from superset.utils.webdriver import WebDriverSelenium
logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)
class CacheWarmupPayload(TypedDict, total=False):
    """Payload identifying a single chart to warm up.

    Declared with ``total=False``, so both keys are optional.
    """

    # ID of the chart to warm up.
    chart_id: int
    # ID of the dashboard the chart is warmed up for, when one is given.
    dashboard_id: int | None
class CacheWarmupTask(TypedDict):
    """A warm-up payload together with the username it should run as."""

    # Chart (and optional dashboard) to warm up.
    payload: CacheWarmupPayload
    # Username of the resolved executor, or None when none could be resolved.
    username: str | None
def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask:
    """Build the warm-up task for a chart, optionally scoped to a dashboard.

    The username comes from the executor resolved for the chart through the
    ``CACHE_WARMUP_EXECUTORS`` config; it is ``None`` when no executor can be
    resolved.
    """
    executor_specs = current_app.config["CACHE_WARMUP_EXECUTORS"]

    warmup_payload: CacheWarmupPayload = {"chart_id": chart.id}
    if dashboard:
        warmup_payload["dashboard_id"] = dashboard.id

    try:
        # The second element of the resolved executor is the username.
        resolved_username: str | None = get_executor(executor_specs, chart)[1]
    except (ExecutorNotFoundError, InvalidExecutorError):
        resolved_username = None

    return {"payload": warmup_payload, "username": resolved_username}
def get_dash_url(dashboard: Dashboard) -> str:
    """Build the absolute URL used to warm up the cache of ``dashboard``.

    The base is the ``WEBDRIVER_BASEURL`` config value when it is set;
    otherwise it is assembled from the ``SUPERSET_WEBSERVER_PROTOCOL``,
    ``SUPERSET_WEBSERVER_ADDRESS`` and ``SUPERSET_WEBSERVER_PORT`` config values.
    """
    # This may run as an async task without a real request, so a test request
    # context is pushed while the URL is built.
    with current_app.test_request_context():
        base = current_app.config.get("WEBDRIVER_BASEURL")
        if not base:
            # Only consulted as a fallback, so the webserver keys are not
            # required when WEBDRIVER_BASEURL is configured.
            base = (
                "{SUPERSET_WEBSERVER_PROTOCOL}://"
                "{SUPERSET_WEBSERVER_ADDRESS}:"
                "{SUPERSET_WEBSERVER_PORT}".format(**current_app.config)
            )
        # NOTE(review): plain concatenation, no slash normalisation -- a base
        # ending in "/" joined with a path starting with "/" yields "//".
        # Confirm that is acceptable for the configured base URL.
        return f"{base}{dashboard.url}"
class Strategy: # pylint: disable=too-few-public-methods
"""
A cache warm up strategy.
Each strategy defines a `get_tasks` method that returns a list of tasks to
send to the `/api/v1/chart/warm_up_cache` endpoint.
Each strategy defines a `get_urls` method that returns a list of dashboard URLs to
warm up using WebDriver.
Strategies can be configured in `superset/config.py`:
@@ -96,15 +77,16 @@ class Strategy: # pylint: disable=too-few-public-methods
def __init__(self) -> None:
pass
def get_tasks(self) -> list[CacheWarmupTask]:
raise NotImplementedError("Subclasses must implement get_tasks!")
def get_urls(self) -> list[str]:
raise NotImplementedError("Subclasses must implement get_urls!")
class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
"""
Warm up all charts.
Warm up all published dashboards.
This is a dummy strategy that will fetch all charts. Can be configured by:
This is a dummy strategy that will fetch all published dashboards.
Can be configured by:
beat_schedule = {
'cache-warmup-hourly': {
@@ -118,8 +100,16 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dummy"
def get_tasks(self) -> list[CacheWarmupTask]:
    """Return one warm-up task for every chart returned by the query."""
    every_chart = db.session.query(Slice).all()
    warmup_tasks: list[CacheWarmupTask] = []
    for chart in every_chart:
        warmup_tasks.append(get_task(chart))
    return warmup_tasks
def get_urls(self) -> list[str]:
    """Return the URL of every published dashboard that has at least one slice."""
    # Eager-load the slices in the same round of queries; otherwise the
    # emptiness check below would trigger one extra query per dashboard.
    dashboard_query = db.session.query(Dashboard).options(
        selectinload(Dashboard.slices)
    )
    published = dashboard_query.filter(Dashboard.published.is_(True)).all()

    urls: list[str] = []
    for dashboard in published:
        if dashboard.slices:
            urls.append(get_dash_url(dashboard))
    return urls
class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -147,7 +137,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
self.top_n = top_n
self.since = parse_human_datetime(since) if since else None
def get_tasks(self) -> list[CacheWarmupTask]:
def get_urls(self) -> list[str]:
records = (
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
@@ -161,11 +151,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
return [
get_task(chart, dashboard)
for dashboard in dashboards
for chart in dashboard.slices
]
return [get_dash_url(dashboard) for dashboard in dashboards]
class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -190,8 +176,8 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
super().__init__()
self.tags = tags or []
def get_tasks(self) -> list[CacheWarmupTask]:
tasks = []
def get_urls(self) -> list[str]:
urls = []
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
@@ -211,73 +197,14 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
tasks.append(get_task(chart))
urls.append(get_dash_url(dashboard))
# add charts that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
tasks.append(get_task(chart))
return tasks
return urls
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="fetch_url")
def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
    """
    Celery job that sends one warm-up request to the chart warm-up endpoint.

    Issues a PUT with ``data`` as the request body to the URL resolved for
    ``ChartRestApi.warm_up_cache``. Before the request is sent, a ``Referer``
    header is added when the URL is secure and the headers returned by
    ``fetch_csrf_token`` are merged in.

    :param data: request body; it is encoded as UTF-8 before sending
    :param headers: request headers; this dict is updated in place
    :returns: ``{"success": data, "response": <body>}`` on HTTP 200,
        ``{"error": data, "status_code": <code>}`` on any other status, or
        ``{"error": data, "exception": <message>}`` when a ``URLError`` is raised
    """
    result = {}
    try:
        url = get_url_path("ChartRestApi.warm_up_cache")

        if is_secure_url(url):
            logger.info("URL '%s' is secure. Adding Referer header.", url)
            headers.update({"Referer": url})

        # Fetch CSRF token for API request
        headers.update(fetch_csrf_token(headers))

        logger.info("Fetching %s with payload %s", url, data)
        req = request.Request(  # noqa: S310
            url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
        )
        response = request.urlopen(  # pylint: disable=consider-using-with  # noqa: S310
            req, timeout=600
        )
        try:
            logger.info(
                "Fetched %s with payload %s, status code: %s", url, data, response.code
            )
            if response.code == 200:
                result = {"success": data, "response": response.read().decode("utf-8")}
            else:
                result = {"error": data, "status_code": response.code}
                logger.error(
                    "Error fetching %s with payload %s, status code: %s",
                    url,
                    data,
                    response.code,
                )
        finally:
            # Always release the underlying connection, whether the body was
            # read or not; previously the response was never closed.
            response.close()
    except URLError as err:
        logger.exception("Error warming up cache!")
        result = {"error": data, "exception": str(err)}
    return result
@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
@@ -285,7 +212,7 @@ def cache_warmup(
"""
Warm up cache.
This task periodically hits charts to warm up the cache.
This task periodically hits dashboards to warm up the cache.
"""
logger.info("Loading strategy")
@@ -307,25 +234,33 @@ def cache_warmup(
logger.exception(message)
return message
results: dict[str, list[str]] = {"scheduled": [], "errors": []}
for task in strategy.get_tasks():
username = task["username"]
payload = json.dumps(task["payload"])
if username:
results: dict[str, list[str]] = {"success": [], "errors": []}
user = security_manager.find_user(
username=current_app.config["SUPERSET_CACHE_WARMUP_USER"]
)
if not user:
message = (
f"Cache warmup user '{current_app.config['SUPERSET_CACHE_WARMUP_USER']}' "
"not found. Please configure SUPERSET_CACHE_WARMUP_USER with a valid "
"username."
)
logger.error(message)
return message
wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user)
try:
for url in strategy.get_urls():
try:
user = security_manager.get_user_by_username(username)
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {
"Cookie": "session=%s" % cookies.get("session", ""),
"Content-Type": "application/json",
}
logger.info("Scheduling %s", payload)
fetch_url.delay(payload, headers)
results["scheduled"].append(payload)
except SchedulingError:
logger.exception("Error scheduling fetch_url for payload: %s", payload)
results["errors"].append(payload)
else:
logger.warning("Executor not found for %s", payload)
logger.info("Fetching %s", url)
wd.get_screenshot(url, "grid-container")
results["success"].append(url)
except (WebDriverException, Exception) as ex: # noqa: BLE001
logger.exception("Error warming up cache for %s: %s", url, ex)
results["errors"].append(url)
finally:
# Ensure WebDriver is properly cleaned up
wd.destroy()
return results

View File

@@ -33,8 +33,8 @@ from superset.utils.urls import modify_url_query
from superset.utils.webdriver import (
ChartStandaloneMode,
DashboardStandaloneMode,
WebDriver,
WebDriverPlaywright,
WebDriverProxy,
WebDriverSelenium,
WindowSize,
)
@@ -188,7 +188,9 @@ class BaseScreenshot:
self.url = url
self.screenshot = None
def driver(self, window_size: WindowSize | None = None) -> WebDriver:
def driver(
self, window_size: WindowSize | None = None, user: User | None = None
) -> WebDriverProxy:
window_size = window_size or self.window_size
if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
# Try to use Playwright if available (supports WebGL/DeckGL, unlike Cypress)
@@ -204,13 +206,13 @@ class BaseScreenshot:
)
# Use Selenium as default/fallback
return WebDriverSelenium(self.driver_type, window_size)
return WebDriverSelenium(self.driver_type, window_size, user)
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
driver = self.driver(window_size, user)
self.screenshot = driver.get_screenshot(self.url, self.element)
return self.screenshot
def get_cache_key(

View File

@@ -159,7 +159,9 @@ class WebDriverProxy(ABC):
self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"]
@abstractmethod
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
def get_screenshot(
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
"""
Run webdriver and return a screenshot
"""
@@ -224,7 +226,7 @@ class WebDriverPlaywright(WebDriverProxy):
return element.screenshot()
def get_screenshot( # pylint: disable=too-many-locals, too-many-statements # noqa: C901
self, url: str, element_name: str, user: User
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
if not PLAYWRIGHT_AVAILABLE:
logger.info(
@@ -252,7 +254,8 @@ class WebDriverPlaywright(WebDriverProxy):
context.set_default_timeout(
app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"]
)
self.auth(user, context)
if user:
self.auth(user, context)
page = context.new_page()
try:
page.goto(
@@ -318,7 +321,7 @@ class WebDriverPlaywright(WebDriverProxy):
logger.debug(
"Taking a PNG screenshot of url %s as user %s",
url,
user.username,
user.username if user else "None",
)
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page)
@@ -399,6 +402,29 @@ class WebDriverPlaywright(WebDriverProxy):
class WebDriverSelenium(WebDriverProxy):
def __init__(
    self,
    driver_type: str,
    window: WindowSize | None = None,
    user: User | None = None,
) -> None:
    """Set up the proxy without creating a WebDriver yet.

    :param driver_type: forwarded to ``WebDriverProxy.__init__``
    :param window: forwarded to ``WebDriverProxy.__init__``
    :param user: user the driver is authenticated as once it is created;
        when ``None`` no authentication is performed at creation time
    """
    super().__init__(driver_type, window)
    self._user = user
    # Populated lazily, on first access of the ``driver`` property.
    self._driver: WebDriver | None = None
def __del__(self) -> None:
    """Release the WebDriver, if one was created, when this object is collected.

    The finalizer also runs for instances whose ``__init__`` raised before
    ``_driver`` was assigned, so the attribute is looked up defensively
    instead of assuming it exists.
    """
    if getattr(self, "_driver", None):
        self._destroy()
@property
def driver(self) -> WebDriver:
    """Return the WebDriver held by this object, creating it on first access.

    A newly created driver is resized to the configured window and, when a
    user was supplied, authenticated as that user before being returned.
    """
    # Fast path: reuse the driver that already exists.
    if self._driver:
        return self._driver

    # ``self._driver`` is assigned before authenticating because ``_auth``
    # reads this property again; by then the fast path above is taken.
    self._driver = self._create()
    assert self._driver  # for mypy
    self._driver.set_window_size(*self._window)
    if self._user:
        self._auth(self._user)
    return self._driver
def _create_firefox_driver(
self, pixel_density: float
) -> tuple[type[WebDriver], type[Service], dict[str, Any]]:
@@ -456,6 +482,22 @@ class WebDriverSelenium(WebDriverProxy):
return config
def create(self) -> WebDriver:
    """Create and return a new WebDriver instance.

    This is the public interface for creating a driver. It wraps the
    internal ``_create`` method for backward compatibility.

    The returned driver is not stored on this object: it is not the one
    handed out by the ``driver`` property, and ``destroy()`` does not close
    it, so the caller is responsible for cleaning it up.
    """
    return self._create()
def destroy(self) -> None:
    """Destroy the WebDriver instance held by this object.

    This is the public interface for cleanup. It wraps the internal
    ``_destroy`` method and should be called when done with the driver.
    Calling it when no driver has been created does nothing.
    """
    self._destroy()
def _create(self) -> WebDriver:
pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
# Get driver class and initial kwargs based on driver type
@@ -516,26 +558,32 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug("Init selenium driver")
return driver_class(**kwargs)
def auth(self, user: User) -> WebDriver:
driver = self.create()
def _auth(self, user: User) -> WebDriver:
return machine_auth_provider_factory.instance.authenticate_webdriver(
driver, user
self.driver, user
)
@staticmethod
def destroy(driver: WebDriver, tries: int = 2) -> None:
def _destroy(self) -> None:
"""Destroy a driver"""
if not self._driver:
return
# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions
try:
retry_call(driver.close, max_tries=tries)
retry_call(
self._driver.close,
max_tries=app.config["SCREENSHOT_SELENIUM_RETRIES"],
)
except Exception: # pylint: disable=broad-except # noqa: S110
pass
try:
driver.quit()
self._driver.quit()
except Exception: # pylint: disable=broad-except # noqa: S110
pass
self._driver = None
@staticmethod
def find_unexpected_errors(driver: WebDriver) -> list[str]:
error_messages = []
@@ -592,10 +640,15 @@ class WebDriverSelenium(WebDriverProxy):
return error_messages
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
def get_screenshot( # noqa: C901
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
# Re-authenticate if a different user is passed
if user and user != self._user:
self._user = user
if self._driver:
self._auth(user)
self.driver.get(url)
img: bytes | None = None
selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"]
logger.debug("Sleeping for %i seconds", selenium_headstart)
@@ -607,9 +660,9 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug(
"Wait for the presence of %s at url: %s", element_name, url
)
element = WebDriverWait(driver, self._screenshot_locate_wait).until(
EC.presence_of_element_located((By.CLASS_NAME, element_name))
)
element = WebDriverWait(
self.driver, self._screenshot_locate_wait
).until(EC.presence_of_element_located((By.CLASS_NAME, element_name)))
except TimeoutException:
logger.exception("Selenium timed out requesting url %s", url)
raise
@@ -617,7 +670,7 @@ class WebDriverSelenium(WebDriverProxy):
try:
# chart containers didn't render
logger.debug("Wait for chart containers to draw at url: %s", url)
WebDriverWait(driver, self._screenshot_locate_wait).until(
WebDriverWait(self.driver, self._screenshot_locate_wait).until(
EC.visibility_of_all_elements_located(
(By.CLASS_NAME, "chart-container")
)
@@ -626,7 +679,7 @@ class WebDriverSelenium(WebDriverProxy):
logger.info("Timeout Exception caught")
# Fallback to allow a screenshot of an empty dashboard
try:
WebDriverWait(driver, 0).until(
WebDriverWait(self.driver, 0).until(
EC.visibility_of_all_elements_located(
(By.CLASS_NAME, "grid-container")
)
@@ -643,7 +696,7 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug(
"Wait for loading element of charts to be gone at url: %s", url
)
WebDriverWait(driver, self._screenshot_load_wait).until_not(
WebDriverWait(self.driver, self._screenshot_load_wait).until_not(
EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
)
except TimeoutException:
@@ -658,11 +711,13 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug(
"Taking a PNG screenshot of url %s as user %s",
url,
user.username,
user.username if user else "None",
)
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver)
unexpected_errors = WebDriverSelenium.find_unexpected_errors(
self.driver
)
if unexpected_errors:
logger.warning(
"%i errors found in the screenshot. URL: %s. Errors are: %s",
@@ -689,6 +744,4 @@ class WebDriverSelenium(WebDriverProxy):
"Encountered an unexpected error when requesting url %s", url
)
raise
finally:
self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img

View File

@@ -39,7 +39,7 @@ from superset.tasks.cache import (
DashboardTagsStrategy,
TopNDashboardsStrategy,
)
from superset.utils.urls import get_url_host # noqa: F401
from superset.utils.urls import get_url_host
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
@@ -82,15 +82,9 @@ class TestCacheWarmUp(SupersetTestCase):
self.client.get(f"/superset/dashboard/{dash.id}/")
strategy = TopNDashboardsStrategy(1)
result = strategy.get_tasks()
expected = [
{
"payload": {"chart_id": chart.id, "dashboard_id": dash.id},
"username": "admin",
}
for chart in dash.slices
]
assert len(result) == len(expected)
result = sorted(strategy.get_urls())
expected = sorted([f"{get_url_host()}{dash.url}"])
assert result == expected
def reset_tag(self, tag):
"""Remove associated object from tag, used to reset tests"""
@@ -108,39 +102,27 @@ class TestCacheWarmUp(SupersetTestCase):
self.reset_tag(tag1)
strategy = DashboardTagsStrategy(["tag1"])
assert strategy.get_tasks() == []
result = sorted(strategy.get_urls())
expected = []
assert result == expected
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagType.custom)
dash = self.get_dash_by_slug("births")
tag1_payloads = [{"chart_id": chart.id} for chart in dash.slices]
tag1_urls = [f"{get_url_host()}{dash.url}"]
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectType.dashboard
)
db.session.add(tagged_object)
db.session.commit()
assert len(strategy.get_tasks()) == len(tag1_payloads)
result = sorted(strategy.get_urls())
assert result == tag1_urls
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagType.custom)
self.reset_tag(tag2)
assert strategy.get_tasks() == []
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
chart = dash.slices[0]
tag2_payloads = [{"chart_id": chart.id}]
object_id = chart.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectType.chart
)
db.session.add(tagged_object)
db.session.commit()
assert len(strategy.get_tasks()) == len(tag2_payloads)
strategy = DashboardTagsStrategy(["tag1", "tag2"])
assert len(strategy.get_tasks()) == len(tag1_payloads + tag2_payloads)
result = sorted(strategy.get_urls())
expected = []
assert result == expected

View File

@@ -1,95 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from tests.integration_tests.test_app import app
@pytest.mark.parametrize(
    "base_url, expected_referer",
    [
        ("http://base-url", None),
        ("http://base-url/", None),
        ("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"),
        ("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"),
    ],
    ids=[
        "Without trailing slash (HTTP)",
        "With trailing slash (HTTP)",
        "Without trailing slash (HTTPS)",
        "With trailing slash (HTTPS)",
    ],
)
# mock.patch decorators are applied bottom-up, so the mock arguments arrive in
# the reverse of the order written here: is_secure_url, urlopen, Request,
# fetch_csrf_token.
@mock.patch("superset.tasks.cache.fetch_csrf_token")
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
@mock.patch("superset.tasks.cache.is_secure_url")
def test_fetch_url(
    mock_is_secure_url,
    mock_urlopen,
    mock_request_cls,
    mock_fetch_csrf_token,
    base_url,
    expected_referer,
):
    """Check that ``fetch_url`` PUTs the payload to the warm-up endpoint.

    Runs for HTTP and HTTPS base URLs, with and without a trailing slash, and
    verifies the request URL, body, headers and method, plus the ``success``
    entry of the result.
    """
    # Imported inside the test rather than at module import time.
    from superset.tasks.cache import fetch_url

    mock_request = mock.MagicMock()
    mock_request_cls.return_value = mock_request

    # The response object is the value returned by urlopen() itself.
    mock_urlopen.return_value = mock.MagicMock()
    mock_urlopen.return_value.code = 200

    # Mock the URL validation to return True for HTTPS and False for HTTP
    mock_is_secure_url.return_value = base_url.startswith("https")

    initial_headers = {"Cookie": "cookie", "key": "value"}
    csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}

    # Conditionally add the Referer header and assert its presence
    if expected_referer:
        csrf_headers = csrf_headers | {"Referer": expected_referer}
        assert csrf_headers["Referer"] == expected_referer

    mock_fetch_csrf_token.return_value = csrf_headers

    app.config["WEBDRIVER_BASEURL"] = base_url
    data = "data"
    data_encoded = b"data"

    result = fetch_url(data, initial_headers)

    # Exactly one slash must separate the base URL from the endpoint path.
    expected_url = (
        f"{base_url}/api/v1/chart/warm_up_cache"
        if not base_url.endswith("/")
        else f"{base_url}api/v1/chart/warm_up_cache"
    )

    # NOTE: fetch_url updates the headers dict in place and the mock keeps a
    # reference to that same dict, so this compares the dict as it is now,
    # not as it was when the call was made.
    mock_fetch_csrf_token.assert_called_once_with(initial_headers)

    mock_request_cls.assert_called_once_with(
        expected_url,  # Use the dynamic URL based on base_url
        data=data_encoded,
        headers=csrf_headers,
        method="PUT",
    )
    # assert the same Request object is used
    mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)

    assert data == result["success"]

View File

@@ -1,77 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from tests.integration_tests.test_app import app
@pytest.mark.parametrize(
    "base_url",
    [
        "http://base-url",
        "http://base-url/",
        "https://base-url",
        "https://base-url/",
    ],
    ids=[
        "Without trailing slash (HTTP)",
        "With trailing slash (HTTP)",
        "Without trailing slash (HTTPS)",
        "With trailing slash (HTTPS)",
    ],
)
# mock.patch decorators are applied bottom-up, so ``mock_urlopen`` is the first
# mock argument and ``mock_request_cls`` the second.
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
    """Check that ``fetch_csrf_token`` requests the CSRF endpoint with GET.

    Runs for HTTP and HTTPS base URLs, with and without a trailing slash, and
    verifies the request URL and the ``X-CSRF-Token`` and ``Cookie`` headers
    that are returned.
    """
    # NOTE(review): the patches target names reached through
    # ``superset.tasks.cache`` while the function under test is imported from
    # ``superset.tasks.utils`` -- presumably both resolve to the same
    # ``request`` module object; confirm.
    from superset.tasks.utils import fetch_csrf_token

    mock_request = mock.MagicMock()
    mock_request_cls.return_value = mock_request

    # The response is what urlopen() yields when used as a context manager.
    mock_response = mock.MagicMock()
    mock_urlopen.return_value.__enter__.return_value = mock_response

    mock_response.status = 200
    mock_response.read.return_value = b'{"result": "csrf_token"}'
    mock_response.headers.get_all.return_value = [
        "session=new_session_cookie",
        "async-token=websocket_cookie",
    ]

    app.config["WEBDRIVER_BASEURL"] = base_url
    headers = {"Cookie": "original_session_cookie"}

    result_headers = fetch_csrf_token(headers)

    # Exactly one slash must separate the base URL from the endpoint path.
    expected_url = (
        f"{base_url}/api/v1/security/csrf_token/"
        if not base_url.endswith("/")
        else f"{base_url}api/v1/security/csrf_token/"
    )

    mock_request_cls.assert_called_with(
        expected_url,
        headers=headers,
        method="GET",
    )

    assert result_headers["X-CSRF-Token"] == "csrf_token"
    # Of the two cookies in the mocked response, only the session cookie is
    # expected in the returned Cookie header.
    assert result_headers["Cookie"] == "session=new_session_cookie"  # Updated assertion
    # assert the same Request object is used
    mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)

View File

@@ -147,31 +147,31 @@ class TestWebDriverSelenium(SupersetTestCase):
def test_screenshot_selenium_headstart(
self, mock_sleep, mock_webdriver, mock_webdriver_wait
):
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
app.config["SCREENSHOT_SELENIUM_HEADSTART"] = 5
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_sleep.call_args_list[0] == call(5)
@patch("superset.utils.webdriver.WebDriverWait")
@patch("superset.utils.webdriver.firefox")
def test_screenshot_selenium_locate_wait(self, mock_webdriver, mock_webdriver_wait):
app.config["SCREENSHOT_LOCATE_WAIT"] = 15
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_webdriver_wait.call_args_list[0] == call(ANY, 15)
@patch("superset.utils.webdriver.WebDriverWait")
@patch("superset.utils.webdriver.firefox")
def test_screenshot_selenium_load_wait(self, mock_webdriver, mock_webdriver_wait):
app.config["SCREENSHOT_LOAD_WAIT"] = 15
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_webdriver_wait.call_args_list[2] == call(ANY, 15)
@patch("superset.utils.webdriver.WebDriverWait")
@@ -180,11 +180,11 @@ class TestWebDriverSelenium(SupersetTestCase):
def test_screenshot_selenium_animation_wait(
self, mock_sleep, mock_webdriver, mock_webdriver_wait
):
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] = 4
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_sleep.call_args_list[1] == call(4)