mirror of
https://github.com/apache/superset.git
synced 2026-05-06 16:34:32 +00:00
Compare commits
6 Commits
dashboard-
...
adopt-pr-3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fae0ff2f1 | ||
|
|
c5475c3a2d | ||
|
|
7476838f26 | ||
|
|
6b4ab27f01 | ||
|
|
776ab3cb14 | ||
|
|
2ac03e438c |
@@ -86,6 +86,39 @@ instead requires a cachelib object.
|
||||
|
||||
See [Async Queries via Celery](/admin-docs/configuration/async-queries-celery) for details.
|
||||
|
||||
## Celery beat
|
||||
|
||||
Superset has a Celery task that will periodically warm up the cache based on different strategies.
|
||||
To use it, add the following to your `superset_config.py`:
|
||||
|
||||
```python
|
||||
from celery.schedules import crontab
|
||||
from superset.config import CeleryConfig
|
||||
|
||||
# User that will be used to authenticate and render dashboards for cache warmup
|
||||
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"
|
||||
|
||||
# Extend the default CeleryConfig to add cache warmup schedule
|
||||
class CustomCeleryConfig(CeleryConfig):
|
||||
beat_schedule = {
|
||||
**CeleryConfig.beat_schedule,
|
||||
'cache-warmup-hourly': {
|
||||
'task': 'cache-warmup',
|
||||
'schedule': crontab(minute=0, hour='*'), # hourly
|
||||
'kwargs': {
|
||||
'strategy_name': 'top_n_dashboards',
|
||||
'top_n': 5,
|
||||
'since': '7 days ago',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
CELERY_CONFIG = CustomCeleryConfig
|
||||
```
|
||||
|
||||
This will cache the top 5 most popular dashboards every hour. For other
|
||||
strategies, check the `superset/tasks/cache.py` file.
|
||||
|
||||
## Caching Thumbnails
|
||||
|
||||
This is an optional feature that can be turned on by activating its [feature flag](/admin-docs/configuration/configuring-superset#feature-flags) on config:
|
||||
|
||||
@@ -1056,6 +1056,11 @@ THUMBNAIL_CACHE_CONFIG: CacheConfig = {
|
||||
}
|
||||
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
|
||||
|
||||
# Cache warmup user — must be set explicitly before enabling the cache-warmup
|
||||
# Celery task. Intentionally defaults to None so operators pick a dedicated
|
||||
# least-privilege user rather than inadvertently running warmup as "admin".
|
||||
SUPERSET_CACHE_WARMUP_USER: str | None = None
|
||||
|
||||
# Time before selenium times out after trying to locate an element on the page and wait
|
||||
# for that element to load for a screenshot.
|
||||
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())
|
||||
|
||||
@@ -228,7 +228,11 @@ class WebDriverPool:
|
||||
def _destroy_driver(self, pooled_driver: PooledWebDriver) -> None:
|
||||
"""Safely destroy a WebDriver instance"""
|
||||
try:
|
||||
WebDriverSelenium.destroy(pooled_driver.driver)
|
||||
try:
|
||||
pooled_driver.driver.close()
|
||||
except Exception: # pylint: disable=broad-except # noqa: S110
|
||||
pass
|
||||
pooled_driver.driver.quit()
|
||||
self._stats["destroyed"] += 1
|
||||
logger.debug("Destroyed WebDriver instance")
|
||||
except Exception as e:
|
||||
|
||||
@@ -17,65 +17,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, TypedDict, Union
|
||||
from urllib import request
|
||||
from urllib.error import URLError
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from celery.beat import SchedulingError
|
||||
from celery.utils.log import get_task_logger
|
||||
from flask import current_app
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.extensions import celery_app
|
||||
from superset.models.core import Log
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.tags.models import Tag, TaggedObject
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError, InvalidExecutorError
|
||||
from superset.tasks.utils import fetch_csrf_token, get_executor
|
||||
from superset.utils import json
|
||||
from superset.utils.date_parser import parse_human_datetime
|
||||
from superset.utils.machine_auth import MachineAuthProvider
|
||||
from superset.utils.urls import get_url_path, is_secure_url
|
||||
from superset.utils.webdriver import WebDriverSelenium
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class CacheWarmupPayload(TypedDict, total=False):
|
||||
chart_id: int
|
||||
dashboard_id: int | None
|
||||
|
||||
|
||||
class CacheWarmupTask(TypedDict):
|
||||
payload: CacheWarmupPayload
|
||||
username: str | None
|
||||
|
||||
|
||||
def get_task(chart: Slice, dashboard: Optional[Dashboard] = None) -> CacheWarmupTask:
|
||||
"""Return task for warming up a given chart/table cache."""
|
||||
executors = current_app.config["CACHE_WARMUP_EXECUTORS"]
|
||||
payload: CacheWarmupPayload = {"chart_id": chart.id}
|
||||
if dashboard:
|
||||
payload["dashboard_id"] = dashboard.id
|
||||
|
||||
username: str | None
|
||||
try:
|
||||
executor = get_executor(executors, chart)
|
||||
username = executor[1]
|
||||
except (ExecutorNotFoundError, InvalidExecutorError):
|
||||
username = None
|
||||
|
||||
return {"payload": payload, "username": username}
|
||||
def get_dash_url(dashboard: Dashboard) -> str:
|
||||
"""Return external URL for warming up a given dashboard cache."""
|
||||
with current_app.test_request_context():
|
||||
baseurl = (
|
||||
# when running this as an async task, drop the request context with
|
||||
# app.test_request_context()
|
||||
current_app.config.get("WEBDRIVER_BASEURL")
|
||||
or "{SUPERSET_WEBSERVER_PROTOCOL}://"
|
||||
"{SUPERSET_WEBSERVER_ADDRESS}:"
|
||||
"{SUPERSET_WEBSERVER_PORT}".format(**current_app.config)
|
||||
)
|
||||
return f"{baseurl.rstrip('/')}{dashboard.url}"
|
||||
|
||||
|
||||
class Strategy: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A cache warm up strategy.
|
||||
|
||||
Each strategy defines a `get_tasks` method that returns a list of tasks to
|
||||
send to the `/api/v1/chart/warm_up_cache` endpoint.
|
||||
Each strategy defines a `get_urls` method that returns a list of dashboard URLs to
|
||||
warm up using WebDriver.
|
||||
|
||||
Strategies can be configured in `superset/config.py`:
|
||||
|
||||
@@ -96,15 +77,16 @@ class Strategy: # pylint: disable=too-few-public-methods
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_tasks(self) -> list[CacheWarmupTask]:
|
||||
raise NotImplementedError("Subclasses must implement get_tasks!")
|
||||
def get_urls(self) -> list[str]:
|
||||
raise NotImplementedError("Subclasses must implement get_urls!")
|
||||
|
||||
|
||||
class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Warm up all charts.
|
||||
Warm up all published dashboards.
|
||||
|
||||
This is a dummy strategy that will fetch all charts. Can be configured by:
|
||||
This is a dummy strategy that will fetch all published dashboards.
|
||||
Can be configured by:
|
||||
|
||||
beat_schedule = {
|
||||
'cache-warmup-hourly': {
|
||||
@@ -118,8 +100,16 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
|
||||
name = "dummy"
|
||||
|
||||
def get_tasks(self) -> list[CacheWarmupTask]:
|
||||
return [get_task(chart) for chart in db.session.query(Slice).all()]
|
||||
def get_urls(self) -> list[str]:
|
||||
# Use selectinload to avoid N+1 queries when checking dashboard.slices
|
||||
dashboards = (
|
||||
db.session.query(Dashboard)
|
||||
.options(selectinload(Dashboard.slices))
|
||||
.filter(Dashboard.published.is_(True))
|
||||
.all()
|
||||
)
|
||||
|
||||
return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]
|
||||
|
||||
|
||||
class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
@@ -147,7 +137,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
|
||||
self.top_n = top_n
|
||||
self.since = parse_human_datetime(since) if since else None
|
||||
|
||||
def get_tasks(self) -> list[CacheWarmupTask]:
|
||||
def get_urls(self) -> list[str]:
|
||||
records = (
|
||||
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
|
||||
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
|
||||
@@ -161,11 +151,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
|
||||
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
|
||||
)
|
||||
|
||||
return [
|
||||
get_task(chart, dashboard)
|
||||
for dashboard in dashboards
|
||||
for chart in dashboard.slices
|
||||
]
|
||||
return [get_dash_url(dashboard) for dashboard in dashboards]
|
||||
|
||||
|
||||
class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
@@ -190,8 +176,8 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
super().__init__()
|
||||
self.tags = tags or []
|
||||
|
||||
def get_tasks(self) -> list[CacheWarmupTask]:
|
||||
tasks = []
|
||||
def get_urls(self) -> list[str]:
|
||||
urls = []
|
||||
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
@@ -211,73 +197,14 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
|
||||
Dashboard.id.in_(dash_ids)
|
||||
)
|
||||
for dashboard in tagged_dashboards:
|
||||
for chart in dashboard.slices:
|
||||
tasks.append(get_task(chart))
|
||||
urls.append(get_dash_url(dashboard))
|
||||
|
||||
# add charts that are tagged
|
||||
tagged_objects = (
|
||||
db.session.query(TaggedObject)
|
||||
.filter(
|
||||
and_(
|
||||
TaggedObject.object_type == "chart",
|
||||
TaggedObject.tag_id.in_(tag_ids),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
|
||||
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
|
||||
for chart in tagged_charts:
|
||||
tasks.append(get_task(chart))
|
||||
|
||||
return tasks
|
||||
return urls
|
||||
|
||||
|
||||
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
|
||||
|
||||
|
||||
@celery_app.task(name="fetch_url")
|
||||
def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Celery job to fetch url
|
||||
"""
|
||||
result = {}
|
||||
try:
|
||||
url = get_url_path("ChartRestApi.warm_up_cache")
|
||||
|
||||
if is_secure_url(url):
|
||||
logger.info("URL '%s' is secure. Adding Referer header.", url)
|
||||
headers.update({"Referer": url})
|
||||
|
||||
# Fetch CSRF token for API request
|
||||
headers.update(fetch_csrf_token(headers))
|
||||
|
||||
logger.info("Fetching %s with payload %s", url, data)
|
||||
req = request.Request( # noqa: S310
|
||||
url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
|
||||
)
|
||||
response = request.urlopen( # pylint: disable=consider-using-with # noqa: S310
|
||||
req, timeout=600
|
||||
)
|
||||
logger.info(
|
||||
"Fetched %s with payload %s, status code: %s", url, data, response.code
|
||||
)
|
||||
if response.code == 200:
|
||||
result = {"success": data, "response": response.read().decode("utf-8")}
|
||||
else:
|
||||
result = {"error": data, "status_code": response.code}
|
||||
logger.error(
|
||||
"Error fetching %s with payload %s, status code: %s",
|
||||
url,
|
||||
data,
|
||||
response.code,
|
||||
)
|
||||
except URLError as err:
|
||||
logger.exception("Error warming up cache!")
|
||||
result = {"error": data, "exception": str(err)}
|
||||
return result
|
||||
|
||||
|
||||
@celery_app.task(name="cache-warmup")
|
||||
def cache_warmup(
|
||||
strategy_name: str, *args: Any, **kwargs: Any
|
||||
@@ -285,7 +212,7 @@ def cache_warmup(
|
||||
"""
|
||||
Warm up cache.
|
||||
|
||||
This task periodically hits charts to warm up the cache.
|
||||
This task periodically hits dashboards to warm up the cache.
|
||||
|
||||
"""
|
||||
logger.info("Loading strategy")
|
||||
@@ -307,25 +234,39 @@ def cache_warmup(
|
||||
logger.exception(message)
|
||||
return message
|
||||
|
||||
results: dict[str, list[str]] = {"scheduled": [], "errors": []}
|
||||
for task in strategy.get_tasks():
|
||||
username = task["username"]
|
||||
payload = json.dumps(task["payload"])
|
||||
if username:
|
||||
results: dict[str, list[str]] = {"success": [], "errors": []}
|
||||
|
||||
warmup_username = current_app.config.get("SUPERSET_CACHE_WARMUP_USER")
|
||||
if not warmup_username:
|
||||
message = (
|
||||
"SUPERSET_CACHE_WARMUP_USER is not configured. Set it to a dedicated "
|
||||
"least-privilege user with access to the dashboards you want warmed up."
|
||||
)
|
||||
logger.error(message)
|
||||
return message
|
||||
|
||||
user = security_manager.find_user(username=warmup_username)
|
||||
if not user:
|
||||
message = (
|
||||
f"Cache warmup user '{warmup_username}' not found. Please configure "
|
||||
"SUPERSET_CACHE_WARMUP_USER with a valid username."
|
||||
)
|
||||
logger.error(message)
|
||||
return message
|
||||
|
||||
wd = WebDriverSelenium(current_app.config["WEBDRIVER_TYPE"], user=user)
|
||||
|
||||
try:
|
||||
for url in strategy.get_urls():
|
||||
try:
|
||||
user = security_manager.get_user_by_username(username)
|
||||
cookies = MachineAuthProvider.get_auth_cookies(user)
|
||||
headers = {
|
||||
"Cookie": "session=%s" % cookies.get("session", ""),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
logger.info("Scheduling %s", payload)
|
||||
fetch_url.delay(payload, headers)
|
||||
results["scheduled"].append(payload)
|
||||
except SchedulingError:
|
||||
logger.exception("Error scheduling fetch_url for payload: %s", payload)
|
||||
results["errors"].append(payload)
|
||||
else:
|
||||
logger.warning("Executor not found for %s", payload)
|
||||
logger.info("Fetching %s", url)
|
||||
wd.get_screenshot(url, "grid-container")
|
||||
results["success"].append(url)
|
||||
except (WebDriverException, Exception) as ex: # noqa: BLE001
|
||||
logger.exception("Error warming up cache for %s: %s", url, ex)
|
||||
results["errors"].append(url)
|
||||
finally:
|
||||
# Ensure WebDriver is properly cleaned up
|
||||
wd.destroy()
|
||||
|
||||
return results
|
||||
|
||||
@@ -33,8 +33,8 @@ from superset.utils.urls import modify_url_query
|
||||
from superset.utils.webdriver import (
|
||||
ChartStandaloneMode,
|
||||
DashboardStandaloneMode,
|
||||
WebDriver,
|
||||
WebDriverPlaywright,
|
||||
WebDriverProxy,
|
||||
WebDriverSelenium,
|
||||
WindowSize,
|
||||
)
|
||||
@@ -188,7 +188,9 @@ class BaseScreenshot:
|
||||
self.url = url
|
||||
self.screenshot = None
|
||||
|
||||
def driver(self, window_size: WindowSize | None = None) -> WebDriver:
|
||||
def driver(
|
||||
self, window_size: WindowSize | None = None, user: User | None = None
|
||||
) -> WebDriverProxy:
|
||||
window_size = window_size or self.window_size
|
||||
if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
|
||||
# Try to use Playwright if available (supports WebGL/DeckGL, unlike Cypress)
|
||||
@@ -204,13 +206,17 @@ class BaseScreenshot:
|
||||
)
|
||||
|
||||
# Use Selenium as default/fallback
|
||||
return WebDriverSelenium(self.driver_type, window_size)
|
||||
return WebDriverSelenium(self.driver_type, window_size, user)
|
||||
|
||||
def get_screenshot(
|
||||
self, user: User, window_size: WindowSize | None = None
|
||||
) -> bytes | None:
|
||||
driver = self.driver(window_size)
|
||||
self.screenshot = driver.get_screenshot(self.url, self.element, user)
|
||||
driver = self.driver(window_size, user)
|
||||
try:
|
||||
self.screenshot = driver.get_screenshot(self.url, self.element, user)
|
||||
finally:
|
||||
if isinstance(driver, WebDriverSelenium):
|
||||
driver.destroy()
|
||||
return self.screenshot
|
||||
|
||||
def get_cache_key(
|
||||
|
||||
@@ -159,7 +159,9 @@ class WebDriverProxy(ABC):
|
||||
self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"]
|
||||
|
||||
@abstractmethod
|
||||
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
|
||||
def get_screenshot(
|
||||
self, url: str, element_name: str, user: User | None = None
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Run webdriver and return a screenshot
|
||||
"""
|
||||
@@ -224,7 +226,7 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
return element.screenshot()
|
||||
|
||||
def get_screenshot( # pylint: disable=too-many-locals, too-many-statements # noqa: C901
|
||||
self, url: str, element_name: str, user: User
|
||||
self, url: str, element_name: str, user: User | None = None
|
||||
) -> bytes | None:
|
||||
if not PLAYWRIGHT_AVAILABLE:
|
||||
logger.info(
|
||||
@@ -252,7 +254,8 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
context.set_default_timeout(
|
||||
app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"]
|
||||
)
|
||||
self.auth(user, context)
|
||||
if user:
|
||||
self.auth(user, context)
|
||||
page = context.new_page()
|
||||
try:
|
||||
page.goto(
|
||||
@@ -318,7 +321,7 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
logger.debug(
|
||||
"Taking a PNG screenshot of url %s as user %s",
|
||||
url,
|
||||
user.username,
|
||||
user.username if user else "None",
|
||||
)
|
||||
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
|
||||
unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page)
|
||||
@@ -399,6 +402,30 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
|
||||
|
||||
class WebDriverSelenium(WebDriverProxy):
|
||||
def __init__(
|
||||
self,
|
||||
driver_type: str,
|
||||
window: WindowSize | None = None,
|
||||
user: User | None = None,
|
||||
):
|
||||
super().__init__(driver_type, window)
|
||||
self._user = user
|
||||
self._driver: WebDriver | None = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
self._destroy()
|
||||
|
||||
@property
|
||||
def driver(self) -> WebDriver:
|
||||
if not self._driver:
|
||||
self._driver = self._create()
|
||||
if not self._driver:
|
||||
raise RuntimeError("WebDriver creation failed")
|
||||
self._driver.set_window_size(*self._window)
|
||||
if self._user:
|
||||
self._auth(self._user)
|
||||
return self._driver
|
||||
|
||||
def _create_firefox_driver(
|
||||
self, pixel_density: float
|
||||
) -> tuple[type[WebDriver], type[Service], dict[str, Any]]:
|
||||
@@ -456,6 +483,22 @@ class WebDriverSelenium(WebDriverProxy):
|
||||
return config
|
||||
|
||||
def create(self) -> WebDriver:
|
||||
"""Create and return the WebDriver instance.
|
||||
|
||||
This is the public interface for creating the driver. It wraps
|
||||
the internal _create method for backward compatibility.
|
||||
"""
|
||||
return self._create()
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Destroy the WebDriver instance.
|
||||
|
||||
This is the public interface for cleanup. It wraps the internal
|
||||
_destroy method and should be called when done with the driver.
|
||||
"""
|
||||
self._destroy()
|
||||
|
||||
def _create(self) -> WebDriver:
|
||||
pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
|
||||
|
||||
# Get driver class and initial kwargs based on driver type
|
||||
@@ -516,25 +559,29 @@ class WebDriverSelenium(WebDriverProxy):
|
||||
logger.debug("Init selenium driver")
|
||||
return driver_class(**kwargs)
|
||||
|
||||
def auth(self, user: User) -> WebDriver:
|
||||
driver = self.create()
|
||||
return machine_auth_provider_factory.instance.authenticate_webdriver(
|
||||
driver, user
|
||||
def _auth(self, user: User) -> None:
|
||||
"""Authenticate the persistent driver in-place."""
|
||||
if self._driver is None:
|
||||
raise RuntimeError("WebDriver is not initialized")
|
||||
machine_auth_provider_factory.instance.authenticate_webdriver(
|
||||
self._driver, user
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def destroy(driver: WebDriver, tries: int = 2) -> None:
|
||||
"""Destroy a driver"""
|
||||
def _destroy(self, tries: int = 2) -> None:
|
||||
"""Destroy the persistent driver"""
|
||||
if not self._driver:
|
||||
return
|
||||
# This is some very flaky code in selenium. Hence the retries
|
||||
# and catch-all exceptions
|
||||
try:
|
||||
retry_call(driver.close, max_tries=tries)
|
||||
retry_call(self._driver.close, max_tries=tries)
|
||||
except Exception: # pylint: disable=broad-except # noqa: S110
|
||||
pass
|
||||
try:
|
||||
driver.quit()
|
||||
self._driver.quit()
|
||||
except Exception: # pylint: disable=broad-except # noqa: S110
|
||||
pass
|
||||
self._driver = None
|
||||
|
||||
@staticmethod
|
||||
def find_unexpected_errors(driver: WebDriver) -> list[str]:
|
||||
@@ -592,9 +639,16 @@ class WebDriverSelenium(WebDriverProxy):
|
||||
|
||||
return error_messages
|
||||
|
||||
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
|
||||
driver = self.auth(user)
|
||||
driver.set_window_size(*self._window)
|
||||
def get_screenshot( # noqa: C901
|
||||
self, url: str, element_name: str, user: User | None = None
|
||||
) -> bytes | None:
|
||||
# If a user is passed explicitly and differs from the stored user,
|
||||
# update and re-authenticate
|
||||
if user and user != self._user:
|
||||
self._user = user
|
||||
if self._driver:
|
||||
self._destroy()
|
||||
driver = self.driver
|
||||
driver.get(url)
|
||||
img: bytes | None = None
|
||||
selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"]
|
||||
@@ -663,7 +717,7 @@ class WebDriverSelenium(WebDriverProxy):
|
||||
logger.debug(
|
||||
"Taking a PNG screenshot of url %s as user %s",
|
||||
url,
|
||||
user.username,
|
||||
self._user.username if self._user else "None",
|
||||
)
|
||||
|
||||
if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]:
|
||||
@@ -698,5 +752,9 @@ class WebDriverSelenium(WebDriverProxy):
|
||||
logger.warning("exception in webdriver", exc_info=ex)
|
||||
raise
|
||||
finally:
|
||||
self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"])
|
||||
# When used as a persistent driver (e.g., cache warmup),
|
||||
# cleanup is handled externally via destroy().
|
||||
# When used for one-off screenshots, the caller or __del__
|
||||
# handles cleanup.
|
||||
pass
|
||||
return img
|
||||
|
||||
@@ -82,15 +82,9 @@ class TestCacheWarmUp(SupersetTestCase):
|
||||
self.client.get(f"/superset/dashboard/{dash.id}/")
|
||||
|
||||
strategy = TopNDashboardsStrategy(1)
|
||||
result = strategy.get_tasks()
|
||||
expected = [
|
||||
{
|
||||
"payload": {"chart_id": chart.id, "dashboard_id": dash.id},
|
||||
"username": "admin",
|
||||
}
|
||||
for chart in dash.slices
|
||||
]
|
||||
assert len(result) == len(expected)
|
||||
result = sorted(strategy.get_urls())
|
||||
expected = sorted([f"{get_url_host().rstrip('/')}{dash.url}"])
|
||||
assert result == expected
|
||||
|
||||
def reset_tag(self, tag):
|
||||
"""Remove associated object from tag, used to reset tests"""
|
||||
@@ -108,39 +102,27 @@ class TestCacheWarmUp(SupersetTestCase):
|
||||
self.reset_tag(tag1)
|
||||
|
||||
strategy = DashboardTagsStrategy(["tag1"])
|
||||
assert strategy.get_tasks() == []
|
||||
result = sorted(strategy.get_urls())
|
||||
expected = []
|
||||
assert result == expected
|
||||
|
||||
# tag dashboard 'births' with `tag1`
|
||||
tag1 = get_tag("tag1", db.session, TagType.custom)
|
||||
dash = self.get_dash_by_slug("births")
|
||||
tag1_payloads = [{"chart_id": chart.id} for chart in dash.slices]
|
||||
tag1_urls = [f"{get_url_host().rstrip('/')}{dash.url}"]
|
||||
tagged_object = TaggedObject(
|
||||
tag_id=tag1.id, object_id=dash.id, object_type=ObjectType.dashboard
|
||||
)
|
||||
db.session.add(tagged_object)
|
||||
db.session.commit()
|
||||
|
||||
assert len(strategy.get_tasks()) == len(tag1_payloads)
|
||||
result = sorted(strategy.get_urls())
|
||||
assert result == tag1_urls
|
||||
|
||||
strategy = DashboardTagsStrategy(["tag2"])
|
||||
tag2 = get_tag("tag2", db.session, TagType.custom)
|
||||
self.reset_tag(tag2)
|
||||
|
||||
assert strategy.get_tasks() == []
|
||||
|
||||
# tag first slice
|
||||
dash = self.get_dash_by_slug("unicode-test")
|
||||
chart = dash.slices[0]
|
||||
tag2_payloads = [{"chart_id": chart.id}]
|
||||
object_id = chart.id
|
||||
tagged_object = TaggedObject(
|
||||
tag_id=tag2.id, object_id=object_id, object_type=ObjectType.chart
|
||||
)
|
||||
db.session.add(tagged_object)
|
||||
db.session.commit()
|
||||
|
||||
assert len(strategy.get_tasks()) == len(tag2_payloads)
|
||||
|
||||
strategy = DashboardTagsStrategy(["tag1", "tag2"])
|
||||
|
||||
assert len(strategy.get_tasks()) == len(tag1_payloads + tag2_payloads)
|
||||
result = sorted(strategy.get_urls())
|
||||
expected = []
|
||||
assert result == expected
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_url, expected_referer",
|
||||
[
|
||||
("http://base-url", None),
|
||||
("http://base-url/", None),
|
||||
("https://base-url", "https://base-url/api/v1/chart/warm_up_cache"),
|
||||
("https://base-url/", "https://base-url/api/v1/chart/warm_up_cache"),
|
||||
],
|
||||
ids=[
|
||||
"Without trailing slash (HTTP)",
|
||||
"With trailing slash (HTTP)",
|
||||
"Without trailing slash (HTTPS)",
|
||||
"With trailing slash (HTTPS)",
|
||||
],
|
||||
)
|
||||
@mock.patch("superset.tasks.cache.fetch_csrf_token")
|
||||
@mock.patch("superset.tasks.cache.request.Request")
|
||||
@mock.patch("superset.tasks.cache.request.urlopen")
|
||||
@mock.patch("superset.tasks.cache.is_secure_url")
|
||||
def test_fetch_url(
|
||||
mock_is_secure_url,
|
||||
mock_urlopen,
|
||||
mock_request_cls,
|
||||
mock_fetch_csrf_token,
|
||||
base_url,
|
||||
expected_referer,
|
||||
):
|
||||
from superset.tasks.cache import fetch_url
|
||||
|
||||
mock_request = mock.MagicMock()
|
||||
mock_request_cls.return_value = mock_request
|
||||
|
||||
mock_urlopen.return_value = mock.MagicMock()
|
||||
mock_urlopen.return_value.code = 200
|
||||
|
||||
# Mock the URL validation to return True for HTTPS and False for HTTP
|
||||
mock_is_secure_url.return_value = base_url.startswith("https")
|
||||
|
||||
initial_headers = {"Cookie": "cookie", "key": "value"}
|
||||
csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}
|
||||
|
||||
# Conditionally add the Referer header and assert its presence
|
||||
if expected_referer:
|
||||
csrf_headers = csrf_headers | {"Referer": expected_referer}
|
||||
assert csrf_headers["Referer"] == expected_referer
|
||||
|
||||
mock_fetch_csrf_token.return_value = csrf_headers
|
||||
|
||||
app.config["WEBDRIVER_BASEURL"] = base_url
|
||||
data = "data"
|
||||
data_encoded = b"data"
|
||||
|
||||
result = fetch_url(data, initial_headers)
|
||||
|
||||
expected_url = (
|
||||
f"{base_url}/api/v1/chart/warm_up_cache"
|
||||
if not base_url.endswith("/")
|
||||
else f"{base_url}api/v1/chart/warm_up_cache"
|
||||
)
|
||||
|
||||
mock_fetch_csrf_token.assert_called_once_with(initial_headers)
|
||||
|
||||
mock_request_cls.assert_called_once_with(
|
||||
expected_url, # Use the dynamic URL based on base_url
|
||||
data=data_encoded,
|
||||
headers=csrf_headers,
|
||||
method="PUT",
|
||||
)
|
||||
# assert the same Request object is used
|
||||
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
|
||||
|
||||
assert data == result["success"]
|
||||
@@ -1,77 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_url",
|
||||
[
|
||||
"http://base-url",
|
||||
"http://base-url/",
|
||||
"https://base-url",
|
||||
"https://base-url/",
|
||||
],
|
||||
ids=[
|
||||
"Without trailing slash (HTTP)",
|
||||
"With trailing slash (HTTP)",
|
||||
"Without trailing slash (HTTPS)",
|
||||
"With trailing slash (HTTPS)",
|
||||
],
|
||||
)
|
||||
@mock.patch("superset.tasks.cache.request.Request")
|
||||
@mock.patch("superset.tasks.cache.request.urlopen")
|
||||
def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
|
||||
from superset.tasks.utils import fetch_csrf_token
|
||||
|
||||
mock_request = mock.MagicMock()
|
||||
mock_request_cls.return_value = mock_request
|
||||
|
||||
mock_response = mock.MagicMock()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = b'{"result": "csrf_token"}'
|
||||
mock_response.headers.get_all.return_value = [
|
||||
"session=new_session_cookie",
|
||||
"async-token=websocket_cookie",
|
||||
]
|
||||
|
||||
app.config["WEBDRIVER_BASEURL"] = base_url
|
||||
headers = {"Cookie": "original_session_cookie"}
|
||||
|
||||
result_headers = fetch_csrf_token(headers)
|
||||
|
||||
expected_url = (
|
||||
f"{base_url}/api/v1/security/csrf_token/"
|
||||
if not base_url.endswith("/")
|
||||
else f"{base_url}api/v1/security/csrf_token/"
|
||||
)
|
||||
|
||||
mock_request_cls.assert_called_with(
|
||||
expected_url,
|
||||
headers=headers,
|
||||
method="GET",
|
||||
)
|
||||
|
||||
assert result_headers["X-CSRF-Token"] == "csrf_token"
|
||||
assert result_headers["Cookie"] == "session=new_session_cookie" # Updated assertion
|
||||
# assert the same Request object is used
|
||||
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
|
||||
@@ -147,31 +147,31 @@ class TestWebDriverSelenium(SupersetTestCase):
|
||||
def test_screenshot_selenium_headstart(
|
||||
self, mock_sleep, mock_webdriver, mock_webdriver_wait
|
||||
):
|
||||
webdriver = WebDriverSelenium("firefox")
|
||||
user = security_manager.get_user_by_username(ADMIN_USERNAME)
|
||||
webdriver = WebDriverSelenium("firefox", user=user)
|
||||
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
|
||||
app.config["SCREENSHOT_SELENIUM_HEADSTART"] = 5
|
||||
webdriver.get_screenshot(url, "chart-container", user=user)
|
||||
webdriver.get_screenshot(url, "chart-container")
|
||||
assert mock_sleep.call_args_list[0] == call(5)
|
||||
|
||||
@patch("superset.utils.webdriver.WebDriverWait")
|
||||
@patch("superset.utils.webdriver.firefox")
|
||||
def test_screenshot_selenium_locate_wait(self, mock_webdriver, mock_webdriver_wait):
|
||||
app.config["SCREENSHOT_LOCATE_WAIT"] = 15
|
||||
webdriver = WebDriverSelenium("firefox")
|
||||
user = security_manager.get_user_by_username(ADMIN_USERNAME)
|
||||
webdriver = WebDriverSelenium("firefox", user=user)
|
||||
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
|
||||
webdriver.get_screenshot(url, "chart-container", user=user)
|
||||
webdriver.get_screenshot(url, "chart-container")
|
||||
assert mock_webdriver_wait.call_args_list[0] == call(ANY, 15)
|
||||
|
||||
@patch("superset.utils.webdriver.WebDriverWait")
|
||||
@patch("superset.utils.webdriver.firefox")
|
||||
def test_screenshot_selenium_load_wait(self, mock_webdriver, mock_webdriver_wait):
|
||||
app.config["SCREENSHOT_LOAD_WAIT"] = 15
|
||||
webdriver = WebDriverSelenium("firefox")
|
||||
user = security_manager.get_user_by_username(ADMIN_USERNAME)
|
||||
webdriver = WebDriverSelenium("firefox", user=user)
|
||||
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
|
||||
webdriver.get_screenshot(url, "chart-container", user=user)
|
||||
webdriver.get_screenshot(url, "chart-container")
|
||||
assert mock_webdriver_wait.call_args_list[2] == call(ANY, 15)
|
||||
|
||||
@patch("superset.utils.webdriver.WebDriverWait")
|
||||
@@ -180,11 +180,11 @@ class TestWebDriverSelenium(SupersetTestCase):
|
||||
def test_screenshot_selenium_animation_wait(
|
||||
self, mock_sleep, mock_webdriver, mock_webdriver_wait
|
||||
):
|
||||
webdriver = WebDriverSelenium("firefox")
|
||||
user = security_manager.get_user_by_username(ADMIN_USERNAME)
|
||||
webdriver = WebDriverSelenium("firefox", user=user)
|
||||
url = get_url_path("Superset.slice", slice_id=1, standalone="true")
|
||||
app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] = 4
|
||||
webdriver.get_screenshot(url, "chart-container", user=user)
|
||||
webdriver.get_screenshot(url, "chart-container")
|
||||
assert mock_sleep.call_args_list[1] == call(4)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user