Compare commits

6 Commits

Author SHA1 Message Date
Evan Rusackas
6fae0ff2f1 address review: default SUPERSET_CACHE_WARMUP_USER to None, fail fast
Warmup ran as "admin" by default, which is the highest-privilege user
in a fresh install. If an operator enables the cache-warmup Celery beat
without explicit configuration, that default silently renders dashboards
as admin — larger blast radius than needed.

Now the default is None, and cache_warmup() returns a clear error
message pointing operators at SUPERSET_CACHE_WARMUP_USER before it
even tries to look up a user. Matches the reviewer's least-privilege
suggestion.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-23 08:54:58 -07:00
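
A minimal sketch of the fail-fast check this commit describes, kept independent of Superset so it runs on its own. The helper name `resolve_warmup_user` and the `find_user` callback are hypothetical stand-ins for illustration; only the config key `SUPERSET_CACHE_WARMUP_USER` comes from the change itself.

```python
from typing import Callable, Optional


def resolve_warmup_user(
    config: dict,
    find_user: Callable[[str], Optional[object]],
) -> tuple[Optional[object], Optional[str]]:
    """Return (user, None) on success or (None, error_message) on failure."""
    username = config.get("SUPERSET_CACHE_WARMUP_USER")
    if not username:
        # Fail fast: no user lookup is attempted when nothing is configured.
        return None, "SUPERSET_CACHE_WARMUP_USER is not configured."
    user = find_user(username)
    if user is None:
        return None, f"Cache warmup user '{username}' not found."
    return user, None
```

With the default of `None`, the first branch is taken and the lookup callback is never invoked.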
Evan Rusackas
c5475c3a2d address review: strip trailing slash in cache-warmup test expectations
get_dash_url() now rstrips the trailing slash from WEBDRIVER_BASEURL, so
the test expectations need the same treatment — otherwise a baseurl that
ends in / produces double-slash URLs that no longer match strategy
output. Fixes both test_top_n_dashboards_strategy and
test_dashboard_tags_strategy.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 12:39:25 -07:00
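
To illustrate the double-slash problem this commit addresses, here is a small standalone sketch. The function name `join_base_and_path` is hypothetical; the `rstrip('/')` call mirrors what `get_dash_url()` does in the diff.

```python
def join_base_and_path(baseurl: str, path: str) -> str:
    """Join a base URL and an absolute path without producing '//'."""
    # Dashboard paths already start with '/', so any trailing slash on the
    # base URL has to go before concatenating.
    return f"{baseurl.rstrip('/')}{path}"


def naive_join(baseurl: str, path: str) -> str:
    """Plain concatenation, shown only to contrast with the fix."""
    return f"{baseurl}{path}"
```

Both `http://host` and `http://host/` yield the same result with the first function, which is why the test expectations need the same stripping.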
Evan Rusackas
7476838f26 address review: forward user to driver.get_screenshot
WebDriverPlaywright's get_screenshot still needs the user argument to
authenticate its browser context; without it the Playwright path renders
private dashboards as unauthenticated pages. WebDriverSelenium already
accepts the optional user kwarg and re-authenticates in-place if it
differs from the stored one.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 12:38:12 -07:00
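
A rough sketch of the stored-user behaviour, using stand-in classes rather than Selenium. `FakeDriver` and `PersistentSession` are hypothetical names; the shape follows the `get_screenshot` hunk further down, where a differing user drops the existing driver so the next access creates and authenticates a fresh one.

```python
from typing import Optional


class FakeDriver:
    """Stand-in for a browser driver; records who it was authenticated as."""

    def __init__(self) -> None:
        self.authenticated_as: Optional[str] = None


class PersistentSession:
    def __init__(self, user: Optional[str] = None) -> None:
        self._user = user
        self._driver: Optional[FakeDriver] = None
        self.created = 0  # counts driver creations, for illustration only

    @property
    def driver(self) -> FakeDriver:
        # Lazily create the driver and authenticate it once on creation.
        if self._driver is None:
            self._driver = FakeDriver()
            self.created += 1
            if self._user:
                self._driver.authenticated_as = self._user
        return self._driver

    def fetch(self, url: str, user: Optional[str] = None) -> str:
        # A different explicit user invalidates the current driver.
        if user and user != self._user:
            self._user = user
            self.destroy()
        return f"{url} as {self.driver.authenticated_as}"

    def destroy(self) -> None:
        self._driver = None
```

Repeated calls with the same user reuse one driver; passing a different user triggers exactly one re-creation.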
Evan Rusackas
6b4ab27f01 address review: replace assert in _auth with explicit RuntimeError
Assertions can be disabled at runtime, so use an explicit check and
raise instead — matches how driver creation failure is already handled.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-22 12:35:42 -07:00
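
For context on why the assertion was replaced: Python drops `assert` statements when run with `-O`, so a guard written as an assertion can vanish in optimized deployments. The sketch below uses a hypothetical `DriverHolder` class to show the explicit form.

```python
from typing import Optional


class DriverHolder:
    def __init__(self, driver: Optional[object] = None) -> None:
        self._driver = driver

    def require_driver(self) -> object:
        # `assert self._driver is not None` would be skipped under
        # `python -O`; an explicit check and raise always executes.
        if self._driver is None:
            raise RuntimeError("WebDriver is not initialized")
        return self._driver
```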
Superset Dev
776ab3cb14 fix(webdriver): address review feedback on cache warmup WebDriver changes
- Fix _auth() to authenticate self._driver in-place instead of creating
  a second, leaked driver (critical bug: persistent driver was never authenticated)
- Replace assert with explicit RuntimeError for driver creation failure
- Fix get_dash_url() to strip trailing slash from WEBDRIVER_BASEURL to
  avoid double-slash URLs (e.g. http://host//superset/dashboard/1/)
- Fix BaseScreenshot.get_screenshot() to call driver.destroy() in a
  try/finally block, preventing Selenium process leaks for one-off screenshots
- Fix webdriver_pool._destroy_driver() to directly call close()/quit()
  on the raw WebDriver since destroy() is now an instance method, not static

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:11:23 -07:00
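
The try/finally fix for one-off screenshots can be sketched as follows. `FlakyDriver` and `one_off_screenshot` are hypothetical stand-ins, not Superset classes; the point is only that cleanup runs even when rendering raises.

```python
class FlakyDriver:
    """Stand-in driver whose screenshot call always fails."""

    def __init__(self) -> None:
        self.destroyed = False

    def get_screenshot(self, url: str) -> bytes:
        raise ValueError(f"render failed for {url}")

    def destroy(self) -> None:
        self.destroyed = True


def one_off_screenshot(driver: FlakyDriver, url: str) -> bytes:
    try:
        return driver.get_screenshot(url)
    finally:
        # Runs on success and on failure, so the browser process is
        # released either way.
        driver.destroy()
```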
Evan Rusackas
2ac03e438c fix: cache warmup using WebDriver for reliable authentication
Adopted from PR #34525 by @rusackas (originally PR #20387 by @ensky).
  Rebased on master with conflict resolution.

  Changes:
  - Use WebDriver (Selenium) to render dashboards for cache warmup
  - Add SUPERSET_CACHE_WARMUP_USER config for specifying the warmup user
  - Support persistent WebDriver instances for efficiency
  - Warm up entire dashboards instead of individual charts
  - Add Celery beat configuration documentation
  - Remove obsolete HTTP-based cache warmup tests

  Co-Authored-By: Evan Rusackas <evan@rusackas.com>
  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-05 09:21:38 -08:00
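
As a reading aid for the `cache_warmup` hunk below, here is a simplified sketch of the dashboard-level loop under stated assumptions: `render` and `cleanup` are hypothetical callables standing in for the WebDriver's screenshot and destroy calls, and the result keys match the `success`/`errors` dictionary in the diff.

```python
from typing import Callable, Iterable


def warm_up(
    urls: Iterable[str],
    render: Callable[[str], None],
    cleanup: Callable[[], None],
) -> dict[str, list[str]]:
    results: dict[str, list[str]] = {"success": [], "errors": []}
    try:
        for url in urls:
            try:
                render(url)
                results["success"].append(url)
            except Exception:  # one failing dashboard must not stop the rest
                results["errors"].append(url)
    finally:
        # The shared driver is released once, after all URLs are processed.
        cleanup()
    return results
```

Each URL is handled in isolation, so a single broken dashboard ends up in `errors` while the remaining ones are still warmed.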
10 changed files with 223 additions and 366 deletions

View File

@@ -86,6 +86,39 @@ instead requires a cachelib object.
See [Async Queries via Celery](/admin-docs/configuration/async-queries-celery) for details.
## Celery beat
Superset has a Celery task that will periodically warm up the cache based on different strategies.
To use it, add the following to your `superset_config.py`:
```python
from celery.schedules import crontab
from superset.config import CeleryConfig
# User that will be used to authenticate and render dashboards for cache warmup
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"
# Extend the default CeleryConfig to add cache warmup schedule
class CustomCeleryConfig(CeleryConfig):
beat_schedule = {
**CeleryConfig.beat_schedule,
'cache-warmup-hourly': {
'task': 'cache-warmup',
'schedule': crontab(minute=0, hour='*'), # hourly
'kwargs': {
'strategy_name': 'top_n_dashboards',
'top_n': 5,
'since': '7 days ago',
},
},
}
CELERY_CONFIG = CustomCeleryConfig
```
This will cache the top 5 most popular dashboards every hour. For other
strategies, check the `superset/tasks/cache.py` file.
## Caching Thumbnails
This is an optional feature that can be turned on by activating its [feature flag](/admin-docs/configuration/configuring-superset#feature-flags) on config:

View File

@@ -1056,6 +1056,11 @@ THUMBNAIL_CACHE_CONFIG: CacheConfig = {
}
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
# Cache warmup user — must be set explicitly before enabling the cache-warmup
# Celery task. Intentionally defaults to None so operators pick a dedicated
# least-privilege user rather than inadvertently running warmup as "admin".
SUPERSET_CACHE_WARMUP_USER: str | None = None
# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())

View File

@@ -228,7 +228,11 @@ class WebDriverPool:
def _destroy_driver(self, pooled_driver: PooledWebDriver) -> None:
"""Safely destroy a WebDriver instance"""
try:
WebDriverSelenium.destroy(pooled_driver.driver)
try:
pooled_driver.driver.close()
except Exception: # pylint: disable=broad-except # noqa: S110
pass
pooled_driver.driver.quit()
self._stats["destroyed"] += 1
logger.debug("Destroyed WebDriver instance")
except Exception as e:

View File

@@ -17,65 +17,46 @@
from __future__ import annotations
import logging
from typing import Any, Optional, TypedDict, Union
from urllib import request
from urllib.error import URLError
from typing import Any, Optional, Union
from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
from flask import current_app
from selenium.common.exceptions import WebDriverException
from sqlalchemy import and_, func
from sqlalchemy.orm import selectinload
from superset import db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
from superset.tasks.utils import fetch_csrf_token, get_executor
from superset.utils import json
from superset.utils.date_parser import parse_human_datetime
from superset.utils.machine_auth import MachineAuthProvider
from superset.utils.urls import get_url_path, is_secure_url
from superset.utils.webdriver import WebDriverSelenium
logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)
class CacheWarmupPayload(TypedDict, total=False):
chart_id: int
dashboard_id: int | None
class CacheWarmupTask(TypedDict):
payload: CacheWarmupPayload
username: str | None
def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask:
"""Return task for warming up a given chart/table cache."""
executors = current_app.config["CACHE_WARMUP_EXECUTORS"]
payload: CacheWarmupPayload = {"chart_id": chart.id}
if dashboard:
payload["dashboard_id"] = dashboard.id
username: str | None
try:
executor = get_executor(executors, chart)
username = executor[1]
except (ExecutorNotFoundError, InvalidExecutorError):
username = None
return {"payload": payload, "username": username}
def get_dash_url(dashboard: Dashboard) -> str:
"""Return external URL for warming up a given dashboard cache."""
with current_app.test_request_context():
baseurl = (
# when running this as an async task, drop the request context with
# app.test_request_context()
current_app.config.get("WEBDRIVER_BASEURL")
or "{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**current_app.config)
)
return f"{baseurl.rstrip('/')}{dashboard.url}"
class Strategy: # pylint: disable=too-few-public-methods
"""
A cache warm up strategy.
Each strategy defines a `get_tasks` method that returns a list of tasks to
send to the `/api/v1/chart/warm_up_cache` endpoint.
Each strategy defines a `get_urls` method that returns a list of dashboard URLs to
warm up using WebDriver.
Strategies can be configured in `superset/config.py`:
@@ -96,15 +77,16 @@ class Strategy: # pylint: disable=too-few-public-methods
def __init__(self) -> None:
pass
def get_tasks(self) -> list[CacheWarmupTask]:
raise NotImplementedError("Subclasses must implement get_tasks!")
def get_urls(self) -> list[str]:
raise NotImplementedError("Subclasses must implement get_urls!")
class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
"""
Warm up all charts.
Warm up all published dashboards.
This is a dummy strategy that will fetch all charts. Can be configured by:
This is a dummy strategy that will fetch all published dashboards.
Can be configured by:
beat_schedule = {
'cache-warmup-hourly': {
@@ -118,8 +100,16 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dummy"
def get_tasks(self) -> list[CacheWarmupTask]:
return [get_task(chart) for chart in db.session.query(Slice).all()]
def get_urls(self) -> list[str]:
# Use selectinload to avoid N+1 queries when checking dashboard.slices
dashboards = (
db.session.query(Dashboard)
.options(selectinload(Dashboard.slices))
.filter(Dashboard.published.is_(True))
.all()
)
return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]
class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -147,7 +137,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
self.top_n = top_n
self.since = parse_human_datetime(since) if since else None
def get_tasks(self) -> list[CacheWarmupTask]:
def get_urls(self) -> list[str]:
records = (
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
@@ -161,11 +151,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
return [
get_task(chart, dashboard)
for dashboard in dashboards
for chart in dashboard.slices
]
return [get_dash_url(dashboard) for dashboard in dashboards]
class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -190,8 +176,8 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
super().__init__()
self.tags = tags or []
def get_tasks(self) -> list[CacheWarmupTask]:
tasks = []
def get_urls(self) -> list[str]:
urls = []
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
@@ -211,73 +197,14 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
tasks.append(get_task(chart))
urls.append(get_dash_url(dashboard))
# add charts that are tagged
tagged_objects = (
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
tasks.append(get_task(chart))
return tasks
return urls
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="fetch_url")
def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
"""
Celery job to fetch url
"""
result = {}
try:
url = get_url_path("ChartRestApi.warm_up_cache")
if is_secure_url(url):
logger.info("URL '%s' is secure. Adding Referer header.", url)
headers.update({"Referer": url})
# Fetch CSRF token for API request
headers.update(fetch_csrf_token(headers))
logger.info("Fetching %s with payload %s", url, data)
req = request.Request( # noqa: S310
url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
)
response = request.urlopen( # pylint: disable=consider-using-with # noqa: S310
req, timeout=600
)
logger.info(
"Fetched %s with payload %s, status code: %s", url, data, response.code
)
if response.code == 200:
result = {"success": data, "response": response.read().decode("utf-8")}
else:
result = {"error": data, "status_code": response.code}
logger.error(
"Error fetching %s with payload %s, status code: %s",
url,
data,
response.code,
)
except URLError as err:
logger.exception("Error warming up cache!")
result = {"error": data, "exception": str(err)}
return result
@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
@@ -285,7 +212,7 @@ def cache_warmup(
"""
Warm up cache.
This task periodically hits charts to warm up the cache.
This task periodically hits dashboards to warm up the cache.
"""
logger.info("Loading strategy")
@@ -307,25 +234,39 @@ def cache_warmup(
logger.exception(message)
return message
results: dict[str, list[str]] = {"scheduled": [], "errors": []}
for task in strategy.get_tasks():
username = task["username"]
payload = json.dumps(task["payload"])
if username:
results: dict[str, list[str]] = {"success": [], "errors": []}
warmup_username = current_app.config.get("SUPERSET_CACHE_WARMUP_USER")
if not warmup_username:
message = (
"SUPERSET_CACHE_WARMUP_USER is not configured. Set it to a dedicated "
"least-privilege user with access to the dashboards you want warmed up."
)
logger.error(message)
return message
user = security_manager.find_user(username=warmup_username)
if not user:
message = (
f"Cache warmup user '{warmup_username}' not found. Please configure "
"SUPERSET_CACHE_WARMUP_USER with a valid username."
)
logger.error(message)
return message
wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user)
try:
for url in strategy.get_urls():
try:
user = security_manager.get_user_by_username(username)
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {
"Cookie": "session=%s" % cookies.get("session", ""),
"Content-Type": "application/json",
}
logger.info("Scheduling %s", payload)
fetch_url.delay(payload, headers)
results["scheduled"].append(payload)
except SchedulingError:
logger.exception("Error scheduling fetch_url for payload: %s", payload)
results["errors"].append(payload)
else:
logger.warning("Executor not found for %s", payload)
logger.info("Fetching %s", url)
wd.get_screenshot(url, "grid-container")
results["success"].append(url)
except (WebDriverException, Exception) as ex: # noqa: BLE001
logger.exception("Error warming up cache for %s: %s", url, ex)
results["errors"].append(url)
finally:
# Ensure WebDriver is properly cleaned up
wd.destroy()
return results

View File

@@ -33,8 +33,8 @@ from superset.utils.urls import modify_url_query
from superset.utils.webdriver import (
ChartStandaloneMode,
DashboardStandaloneMode,
WebDriver,
WebDriverPlaywright,
WebDriverProxy,
WebDriverSelenium,
WindowSize,
)
@@ -188,7 +188,9 @@ class BaseScreenshot:
self.url = url
self.screenshot = None
def driver(self, window_size: WindowSize | None = None) -> WebDriver:
def driver(
self, window_size: WindowSize | None = None, user: User | None = None
) -> WebDriverProxy:
window_size = window_size or self.window_size
if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
# Try to use Playwright if available (supports WebGL/DeckGL, unlike Cypress)
@@ -204,13 +206,17 @@ class BaseScreenshot:
)
# Use Selenium as default/fallback
return WebDriverSelenium(self.driver_type, window_size)
return WebDriverSelenium(self.driver_type, window_size, user)
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
driver = self.driver(window_size, user)
try:
self.screenshot = driver.get_screenshot(self.url, self.element, user)
finally:
if isinstance(driver, WebDriverSelenium):
driver.destroy()
return self.screenshot
def get_cache_key(

View File

@@ -159,7 +159,9 @@ class WebDriverProxy(ABC):
self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"]
@abstractmethod
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
def get_screenshot(
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
"""
Run webdriver and return a screenshot
"""
@@ -224,7 +226,7 @@ class WebDriverPlaywright(WebDriverProxy):
return element.screenshot()
def get_screenshot( # pylint: disable=too-many-locals, too-many-statements # noqa: C901
self, url: str, element_name: str, user: User
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
if not PLAYWRIGHT_AVAILABLE:
logger.info(
@@ -252,7 +254,8 @@ class WebDriverPlaywright(WebDriverProxy):
context.set_default_timeout(
app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"]
)
self.auth(user, context)
if user:
self.auth(user, context)
page = context.new_page()
try:
page.goto(
@@ -318,7 +321,7 @@ class WebDriverPlaywright(WebDriverProxy):
logger.debug(
"Taking a PNG screenshot of url %s as user %s",
url,
user.username,
user.username if user else "None",
)
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page)
@@ -399,6 +402,30 @@ class WebDriverPlaywright(WebDriverProxy):
class WebDriverSelenium(WebDriverProxy):
def __init__(
self,
driver_type: str,
window: WindowSize | None = None,
user: User | None = None,
):
super().__init__(driver_type, window)
self._user = user
self._driver: WebDriver | None = None
def __del__(self) -> None:
self._destroy()
@property
def driver(self) -> WebDriver:
if not self._driver:
self._driver = self._create()
if not self._driver:
raise RuntimeError("WebDriver creation failed")
self._driver.set_window_size(*self._window)
if self._user:
self._auth(self._user)
return self._driver
def _create_firefox_driver(
self, pixel_density: float
) -> tuple[type[WebDriver], type[Service], dict[str, Any]]:
@@ -456,6 +483,22 @@ class WebDriverSelenium(WebDriverProxy):
return config
def create(self) -> WebDriver:
"""Create and return the WebDriver instance.
This is the public interface for creating the driver. It wraps
the internal _create method for backward compatibility.
"""
return self._create()
def destroy(self) -> None:
"""Destroy the WebDriver instance.
This is the public interface for cleanup. It wraps the internal
_destroy method and should be called when done with the driver.
"""
self._destroy()
def _create(self) -> WebDriver:
pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
# Get driver class and initial kwargs based on driver type
@@ -516,25 +559,29 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug("Init selenium driver")
return driver_class(**kwargs)
def auth(self, user: User) -> WebDriver:
driver = self.create()
return machine_auth_provider_factory.instance.authenticate_webdriver(
driver, user
def _auth(self, user: User) -> None:
"""Authenticate the persistent driver in-place."""
if self._driver is None:
raise RuntimeError("WebDriver is not initialized")
machine_auth_provider_factory.instance.authenticate_webdriver(
self._driver, user
)
@staticmethod
def destroy(driver: WebDriver, tries: int = 2) -> None:
"""Destroy a driver"""
def _destroy(self, tries: int = 2) -> None:
"""Destroy the persistent driver"""
if not self._driver:
return
# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions
try:
retry_call(driver.close, max_tries=tries)
retry_call(self._driver.close, max_tries=tries)
except Exception: # pylint: disable=broad-except # noqa: S110
pass
try:
driver.quit()
self._driver.quit()
except Exception: # pylint: disable=broad-except # noqa: S110
pass
self._driver = None
@staticmethod
def find_unexpected_errors(driver: WebDriver) -> list[str]:
@@ -592,9 +639,16 @@ class WebDriverSelenium(WebDriverProxy):
return error_messages
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
driver = self.auth(user)
driver.set_window_size(*self._window)
def get_screenshot( # noqa: C901
self, url: str, element_name: str, user: User | None = None
) -> bytes | None:
# If a user is passed explicitly and differs from the stored user,
# update and re-authenticate
if user and user != self._user:
self._user = user
if self._driver:
self._destroy()
driver = self.driver
driver.get(url)
img: bytes | None = None
selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"]
@@ -663,7 +717,7 @@ class WebDriverSelenium(WebDriverProxy):
logger.debug(
"Taking a PNG screenshot of url %s as user %s",
url,
user.username,
self._user.username if self._user else "None",
)
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
@@ -698,5 +752,9 @@ class WebDriverSelenium(WebDriverProxy):
logger.warning("exception in webdriver", exc_info=ex)
raise
finally:
self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"])
# When used as a persistent driver (e.g., cache warmup),
# cleanup is handled externally via destroy().
# When used for one-off screenshots, the caller or __del__
# handles cleanup.
pass
return img

View File

@@ -82,15 +82,9 @@ class TestCacheWarmUp(SupersetTestCase):
self.client.get(f"/superset/dashboard/{dash.id}/")
strategy = TopNDashboardsStrategy(1)
result = strategy.get_tasks()
expected = [
{
"payload": {"chart_id": chart.id, "dashboard_id": dash.id},
"username": "admin",
}
for chart in dash.slices
]
assert len(result) == len(expected)
result = sorted(strategy.get_urls())
expected = sorted([f"{get_url_host().rstrip('/')}{dash.url}"])
assert result == expected
def reset_tag(self, tag):
"""Remove associated object from tag, used to reset tests"""
@@ -108,39 +102,27 @@ class TestCacheWarmUp(SupersetTestCase):
self.reset_tag(tag1)
strategy = DashboardTagsStrategy(["tag1"])
assert strategy.get_tasks() == []
result = sorted(strategy.get_urls())
expected = []
assert result == expected
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagType.custom)
dash = self.get_dash_by_slug("births")
tag1_payloads = [{"chart_id": chart.id} for chart in dash.slices]
tag1_urls = [f"{get_url_host().rstrip('/')}{dash.url}"]
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectType.dashboard
)
db.session.add(tagged_object)
db.session.commit()
assert len(strategy.get_tasks()) == len(tag1_payloads)
result = sorted(strategy.get_urls())
assert result == tag1_urls
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagType.custom)
self.reset_tag(tag2)
assert strategy.get_tasks() == []
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
chart = dash.slices[0]
tag2_payloads = [{"chart_id": chart.id}]
object_id = chart.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectType.chart
)
db.session.add(tagged_object)
db.session.commit()
assert len(strategy.get_tasks()) == len(tag2_payloads)
strategy = DashboardTagsStrategy(["tag1", "tag2"])
assert len(strategy.get_tasks()) == len(tag1_payloads + tag2_payloads)
result = sorted(strategy.get_urls())
expected = []
assert result == expected

View File

@@ -1,95 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from tests.integration_tests.test_app import app
@pytest.mark.parametrize(
"base_url, expected_referer",
[
("http://base-url", None),
("http://base-url/", None),
("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"),
("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"),
],
ids=[
"Without trailing slash (HTTP)",
"With trailing slash (HTTP)",
"Without trailing slash (HTTPS)",
"With trailing slash (HTTPS)",
],
)
@mock.patch("superset.tasks.cache.fetch_csrf_token")
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
@mock.patch("superset.tasks.cache.is_secure_url")
def test_fetch_url(
mock_is_secure_url,
mock_urlopen,
mock_request_cls,
mock_fetch_csrf_token,
base_url,
expected_referer,
):
from superset.tasks.cache import fetch_url
mock_request = mock.MagicMock()
mock_request_cls.return_value = mock_request
mock_urlopen.return_value = mock.MagicMock()
mock_urlopen.return_value.code = 200
# Mock the URL validation to return True for HTTPS and False for HTTP
mock_is_secure_url.return_value = base_url.startswith("https")
initial_headers = {"Cookie": "cookie", "key": "value"}
csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}
# Conditionally add the Referer header and assert its presence
if expected_referer:
csrf_headers = csrf_headers | {"Referer": expected_referer}
assert csrf_headers["Referer"] == expected_referer
mock_fetch_csrf_token.return_value = csrf_headers
app.config["WEBDRIVER_BASEURL"] = base_url
data = "data"
data_encoded = b"data"
result = fetch_url(data, initial_headers)
expected_url = (
f"{base_url}/api/v1/chart/warm_up_cache"
if not base_url.endswith("/")
else f"{base_url}api/v1/chart/warm_up_cache"
)
mock_fetch_csrf_token.assert_called_once_with(initial_headers)
mock_request_cls.assert_called_once_with(
expected_url, # Use the dynamic URL based on base_url
data=data_encoded,
headers=csrf_headers,
method="PUT",
)
# assert the same Request object is used
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
assert data == result["success"]

View File

@@ -1,77 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from tests.integration_tests.test_app import app
@pytest.mark.parametrize(
"base_url",
[
"http://base-url",
"http://base-url/",
"https://base-url",
"https://base-url/",
],
ids=[
"Without trailing slash (HTTP)",
"With trailing slash (HTTP)",
"Without trailing slash (HTTPS)",
"With trailing slash (HTTPS)",
],
)
@mock.patch("superset.tasks.cache.request.Request")
@mock.patch("superset.tasks.cache.request.urlopen")
def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
from superset.tasks.utils import fetch_csrf_token
mock_request = mock.MagicMock()
mock_request_cls.return_value = mock_request
mock_response = mock.MagicMock()
mock_urlopen.return_value.__enter__.return_value = mock_response
mock_response.status = 200
mock_response.read.return_value = b'{"result": "csrf_token"}'
mock_response.headers.get_all.return_value = [
"session=new_session_cookie",
"async-token=websocket_cookie",
]
app.config["WEBDRIVER_BASEURL"] = base_url
headers = {"Cookie": "original_session_cookie"}
result_headers = fetch_csrf_token(headers)
expected_url = (
f"{base_url}/api/v1/security/csrf_token/"
if not base_url.endswith("/")
else f"{base_url}api/v1/security/csrf_token/"
)
mock_request_cls.assert_called_with(
expected_url,
headers=headers,
method="GET",
)
assert result_headers["X-CSRF-Token"] == "csrf_token"
assert result_headers["Cookie"] == "session=new_session_cookie" # Updated assertion
# assert the same Request object is used
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)

View File

@@ -147,31 +147,31 @@ class TestWebDriverSelenium(SupersetTestCase):
def test_screenshot_selenium_headstart(
self, mock_sleep, mock_webdriver, mock_webdriver_wait
):
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
app.config["SCREENSHOT_SELENIUM_HEADSTART"] = 5
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_sleep.call_args_list[0] == call(5)
@patch("superset.utils.webdriver.WebDriverWait")
@patch("superset.utils.webdriver.firefox")
def test_screenshot_selenium_locate_wait(self, mock_webdriver, mock_webdriver_wait):
app.config["SCREENSHOT_LOCATE_WAIT"] = 15
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_webdriver_wait.call_args_list[0] == call(ANY, 15)
@patch("superset.utils.webdriver.WebDriverWait")
@patch("superset.utils.webdriver.firefox")
def test_screenshot_selenium_load_wait(self, mock_webdriver, mock_webdriver_wait):
app.config["SCREENSHOT_LOAD_WAIT"] = 15
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_webdriver_wait.call_args_list[2] == call(ANY, 15)
@patch("superset.utils.webdriver.WebDriverWait")
@@ -180,11 +180,11 @@ class TestWebDriverSelenium(SupersetTestCase):
def test_screenshot_selenium_animation_wait(
self, mock_sleep, mock_webdriver, mock_webdriver_wait
):
webdriver = WebDriverSelenium("firefox")
user = security_manager.get_user_by_username(ADMIN_USERNAME)
webdriver = WebDriverSelenium("firefox", user=user)
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] = 4
webdriver.get_screenshot(url, "chart-container", user=user)
webdriver.get_screenshot(url, "chart-container")
assert mock_sleep.call_args_list[1] == call(4)