diff --git a/docs/docs/contributing/development.mdx b/docs/docs/contributing/development.mdx index f50c4e724ca..ee38ca76349 100644 --- a/docs/docs/contributing/development.mdx +++ b/docs/docs/contributing/development.mdx @@ -421,14 +421,6 @@ Then make sure you run your WSGI server using the right worker type: gunicorn "superset.app:create_app()" -k "geventwebsocket.gunicorn.workers.GeventWebSocketWorker" -b 127.0.0.1:8088 --reload ``` -You can log anything to the browser console, including objects: - -```python -from superset import app -app.logger.error('An exception occurred!') -app.logger.info(form_data) -``` - ### Frontend Frontend assets (TypeScript, JavaScript, CSS, and images) must be compiled in order to properly display the web UI. The `superset-frontend` directory contains all NPM-managed frontend assets. Note that for some legacy pages there are additional frontend assets bundled with Flask-Appbuilder (e.g. jQuery and bootstrap). These are not managed by NPM and may be phased out in the future. diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh index e127d0c0206..6f3f3bddb83 100755 --- a/scripts/python_tests.sh +++ b/scripts/python_tests.sh @@ -33,4 +33,4 @@ superset load-test-users echo "Running tests" -pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@" +pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@" diff --git a/superset/__init__.py b/superset/__init__.py index cbab58e0d2c..450b5f104ff 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from flask import current_app, Flask from werkzeug.local import LocalProxy from superset.app import create_app # noqa: F401 @@ -35,9 +34,7 @@ from superset.security import SupersetSecurityManager # noqa: F401 # to declare "global" dependencies is to define it in extensions.py, # then initialize it in app.create_app(). These fields will be removed # in subsequent PRs as things are migrated towards the factory pattern -app: Flask = current_app cache = cache_manager.cache -conf = LocalProxy(lambda: current_app.config) get_feature_flags = feature_flag_manager.get_feature_flags get_manifest_files = manifest_processor.get_manifest_files is_feature_enabled = feature_flag_manager.is_feature_enabled diff --git a/superset/advanced_data_type/api.py b/superset/advanced_data_type/api.py index 3659429bad9..152aa9b8858 100644 --- a/superset/advanced_data_type/api.py +++ b/superset/advanced_data_type/api.py @@ -29,9 +29,6 @@ from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.extensions import event_logger from superset.views.base_api import BaseSupersetApi -config = app.config -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] - class AdvancedDataTypeRestApi(BaseSupersetApi): """ @@ -96,7 +93,7 @@ class AdvancedDataTypeRestApi(BaseSupersetApi): item = kwargs["rison"] advanced_data_type = item["type"] values = item["values"] - addon = ADVANCED_DATA_TYPES.get(advanced_data_type) + addon = app.config["ADVANCED_DATA_TYPES"].get(advanced_data_type) if not addon: return self.response( 400, @@ -148,4 +145,4 @@ class AdvancedDataTypeRestApi(BaseSupersetApi): 500: $ref: '#/components/responses/500' """ - return self.response(200, result=list(ADVANCED_DATA_TYPES.keys())) + return self.response(200, result=list(app.config["ADVANCED_DATA_TYPES"].keys())) diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index d955cc88278..74bf0f8e0ac 100644 --- a/superset/async_events/async_query_manager.py +++ 
b/superset/async_events/async_query_manager.py @@ -117,9 +117,8 @@ class AsyncQueryManager: self._load_explore_json_into_cache_job: Any = None def init_app(self, app: Flask) -> None: - config = app.config - cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE") - data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE") + cache_type = app.config.get("CACHE_CONFIG", {}).get("CACHE_TYPE") + data_cache_type = app.config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE") if cache_type in [None, "null"] or data_cache_type in [None, "null"]: raise Exception( # pylint: disable=broad-exception-raised """ @@ -128,26 +127,28 @@ class AsyncQueryManager: """ ) - self._cache = get_cache_backend(config) + self._cache = get_cache_backend(app.config) logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__) - if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32: + if len(app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32: raise AsyncQueryTokenException( "Please provide a JWT secret at least 32 bytes long" ) - self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] - self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"] - self._stream_limit_firehose = config[ + self._stream_prefix = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self._stream_limit = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"] + self._stream_limit_firehose = app.config[ "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE" ] - self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"] - self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"] - self._jwt_cookie_samesite = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE"] - self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"] - self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"] + self._jwt_cookie_name = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"] + self._jwt_cookie_secure = 
app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"] + self._jwt_cookie_samesite = app.config[ + "GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE" + ] + self._jwt_cookie_domain = app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"] + self._jwt_secret = app.config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"] - if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]: + if app.config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]: self.register_request_handlers(app) # pylint: disable=import-outside-toplevel diff --git a/superset/available_domains/api.py b/superset/available_domains/api.py index 0b8253d942f..17389bd3928 100644 --- a/superset/available_domains/api.py +++ b/superset/available_domains/api.py @@ -16,10 +16,9 @@ # under the License. import logging -from flask import Response +from flask import current_app as app, Response from flask_appbuilder.api import expose, protect, safe -from superset import conf from superset.available_domains.schemas import AvailableDomainsSchema from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP from superset.extensions import event_logger @@ -70,6 +69,6 @@ class AvailableDomainsRestApi(BaseSupersetApi): $ref: '#/components/responses/403' """ result = self.available_domains_schema.dump( - {"domains": conf.get("SUPERSET_WEBSERVER_DOMAINS")} + {"domains": app.config.get("SUPERSET_WEBSERVER_DOMAINS")} ) return self.response(200, result=result) diff --git a/superset/charts/api.py b/superset/charts/api.py index 7a15d6e5c3e..5abead3f6af 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -30,7 +30,7 @@ from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper -from superset import app, is_feature_enabled +from superset import is_feature_enabled from superset.charts.filters import ( ChartAllTextFilter, ChartCertifiedFilter, @@ -101,7 +101,6 @@ from superset.views.base_api import ( from superset.views.filters import BaseFilterRelatedUsers, 
FilterRelatedOwners logger = logging.getLogger(__name__) -config = app.config class ChartRestApi(BaseSupersetModelRestApi): diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index efae7de8313..3591539a098 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -20,7 +20,7 @@ import contextlib import logging from typing import Any, TYPE_CHECKING -from flask import current_app, g, make_response, request, Response +from flask import current_app as app, g, make_response, request, Response from flask_appbuilder.api import expose, protect from flask_babel import gettext as _ from marshmallow import ValidationError @@ -379,7 +379,7 @@ class ChartDataRestApi(ChartRestApi): # return multi-query results bundled as a zip file def _process_data(query_data: Any) -> Any: if result_format == ChartDataResultFormat.CSV: - encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") + encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8") return query_data.encode(encoding) return query_data diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index a8911ce13da..f1f7a7a19e7 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -20,11 +20,11 @@ from __future__ import annotations import inspect from typing import Any, TYPE_CHECKING +from flask import current_app from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate from marshmallow.validate import Length, Range -from superset import app from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.db_engine_specs.base import builtin_time_grains from superset.utils import pandas_postprocessing, schema as utils @@ -40,7 +40,25 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.common.query_context_factory import QueryContextFactory -config = app.config + +def get_time_grain_choices() -> Any: + """Get time grain choices 
including addons from config""" + try: + # Try to get config from current app context + time_grain_addons = current_app.config.get("TIME_GRAIN_ADDONS", {}) + except RuntimeError: + # Outside app context, use empty addons + time_grain_addons = {} + + return [ + i + for i in { + **builtin_time_grains, + **time_grain_addons, + }.keys() + if i + ] + # # RISON/JSON schemas for query parameters @@ -624,13 +642,7 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.", "example": "P1D", }, - validate=validate.OneOf( - choices=[ - i - for i in {**builtin_time_grains, **config["TIME_GRAIN_ADDONS"]}.keys() - if i - ] - ), + validate=validate.OneOf(choices=get_time_grain_choices()), required=True, ) periods = fields.Integer( @@ -989,13 +1001,7 @@ class ChartDataExtrasSchema(Schema): "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.", "example": "P1D", }, - validate=validate.OneOf( - choices=[ - i - for i in {**builtin_time_grains, **config["TIME_GRAIN_ADDONS"]}.keys() - if i - ] - ), + validate=validate.OneOf(choices=get_time_grain_choices()), allow_none=True, ) instant_time_comparison_range = fields.String( diff --git a/superset/cli/lib.py b/superset/cli/lib.py deleted file mode 100755 index bf6d77ce319..00000000000 --- a/superset/cli/lib.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging - -from superset import config - -logger = logging.getLogger(__name__) - - -feature_flags = config.DEFAULT_FEATURE_FLAGS.copy() -feature_flags.update(config.FEATURE_FLAGS) -feature_flags_func = config.GET_FEATURE_FLAGS_FUNC -if feature_flags_func: - try: - # pylint: disable=not-callable - feature_flags = feature_flags_func(feature_flags) - except Exception: # pylint: disable=broad-except # noqa: S110 - # bypass any feature flags that depend on context - # that's not available - pass - - -def normalize_token(token_name: str) -> str: - """ - As of click>=7, underscores in function names are replaced by dashes. - To avoid the need to rename all cli functions, e.g. load_examples to - load-examples, this function is used to convert dashes back to - underscores. 
- - :param token_name: token name possibly containing dashes - :return: token name where dashes are replaced with underscores - """ - return token_name.replace("_", "-") diff --git a/superset/cli/main.py b/superset/cli/main.py index 42315fd9048..c5ea645cae6 100755 --- a/superset/cli/main.py +++ b/superset/cli/main.py @@ -22,18 +22,39 @@ from typing import Any import click from colorama import Fore, Style +from flask import current_app from flask.cli import FlaskGroup, with_appcontext -from superset import app, appbuilder, cli, security_manager -from superset.cli.lib import normalize_token +from superset import appbuilder, cli, security_manager from superset.extensions import db from superset.utils.decorators import transaction logger = logging.getLogger(__name__) +def normalize_token(token_name: str) -> str: + """ + As of click>=7, underscores in function names are replaced by dashes. + To avoid the need to rename all cli functions, e.g. load_examples to + load-examples, this function is used to convert dashes back to + underscores. 
+ + :param token_name: token name possibly containing dashes + :return: token name where dashes are replaced with underscores + """ + return token_name.replace("_", "-") + + +def create_app() -> Any: + """Create app instance for CLI""" + from superset.app import create_app as create_superset_app + + return create_superset_app() + + @click.group( cls=FlaskGroup, + create_app=create_app, context_settings={"token_normalize_func": normalize_token}, ) @with_appcontext @@ -41,10 +62,6 @@ def superset() -> None: """\033[1;37mThe Apache Superset CLI\033[0m""" # NOTE: codes above are ANSI color codes for bold white in CLI header ^^^ - @app.shell_context_processor - def make_shell_context() -> dict[str, Any]: - return {"app": app, "db": db} - # add sub-commands for load, module_name, is_pkg in pkgutil.walk_packages( # noqa: B007 @@ -73,8 +90,14 @@ def init() -> None: @click.option("--verbose", "-v", is_flag=True, help="Show extra information") def version(verbose: bool) -> None: """Prints the current version number""" + print(Fore.BLUE + "-=" * 15) - print(Fore.YELLOW + "Superset " + Fore.CYAN + f"{app.config['VERSION_STRING']}") + print( + Fore.YELLOW + + "Superset " + + Fore.CYAN + + f"{current_app.config['VERSION_STRING']}" + ) print(Fore.BLUE + "-=" * 15) if verbose: print("[DB] : " + f"{db.engine}") diff --git a/superset/cli/reset.py b/superset/cli/reset.py index fd5e7260754..26ce032f7da 100644 --- a/superset/cli/reset.py +++ b/superset/cli/reset.py @@ -18,57 +18,66 @@ import sys import click +from flask import current_app from flask.cli import with_appcontext from werkzeug.security import check_password_hash -from superset.cli.lib import feature_flags -if feature_flags.get("ENABLE_FACTORY_RESET_COMMAND"): +@click.command() +@with_appcontext +@click.option("--username", prompt="Admin Username", help="Admin Username") +@click.option( + "--silent", + is_flag=True, + prompt=( + "Are you sure you want to reset Superset? " + "This action cannot be undone. Continue?" 
+ ), + help="Confirmation flag", +) +@click.option( + "--exclude-users", + default=None, + help="Comma separated list of users to exclude from reset", +) +@click.option( + "--exclude-roles", + default=None, + help="Comma separated list of roles to exclude from reset", +) +def factory_reset( + username: str, silent: bool, exclude_users: str, exclude_roles: str +) -> None: + """Factory Reset Apache Superset""" - @click.command() - @with_appcontext - @click.option("--username", prompt="Admin Username", help="Admin Username") - @click.option( - "--silent", - is_flag=True, - prompt=( - "Are you sure you want to reset Superset? " - "This action cannot be undone. Continue?" - ), - help="Confirmation flag", - ) - @click.option( - "--exclude-users", - default=None, - help="Comma separated list of users to exclude from reset", - ) - @click.option( - "--exclude-roles", - default=None, - help="Comma separated list of roles to exclude from reset", - ) - def factory_reset( - username: str, silent: bool, exclude_users: str, exclude_roles: str - ) -> None: - """Factory Reset Apache Superset""" + # Check feature flag inside the command + if not current_app.config.get("FEATURE_FLAGS", {}).get( + "ENABLE_FACTORY_RESET_COMMAND" + ): + click.secho( + "Factory reset command is disabled. 
Enable " + "ENABLE_FACTORY_RESET_COMMAND feature flag.", + fg="red", + ) + sys.exit(1) - # pylint: disable=import-outside-toplevel - from superset import security_manager - from superset.commands.security.reset import ResetSupersetCommand + # pylint: disable=import-outside-toplevel + from superset import security_manager + from superset.commands.security.reset import ResetSupersetCommand - # Validate the user - password = click.prompt("Admin Password", hide_input=True) - user = security_manager.find_user(username) - if not user or not check_password_hash(user.password, password): - click.secho("Invalid credentials", fg="red") - sys.exit(1) - if not any(role.name == "Admin" for role in user.roles): - click.secho("Permission Denied", fg="red") - sys.exit(1) + # Validate the user + password = click.prompt("Admin Password", hide_input=True) + user = security_manager.find_user(username) + if not user or not check_password_hash(user.password, password): + click.secho("Invalid credentials", fg="red") + sys.exit(1) + if not any(role.name == "Admin" for role in user.roles): + click.secho("Permission Denied", fg="red") + sys.exit(1) - try: - ResetSupersetCommand(silent, user, exclude_users, exclude_roles).run() - click.secho("Factory reset complete", fg="green") - except Exception as ex: # pylint: disable=broad-except - click.secho(f"Factory reset failed: {ex}", fg="red") - sys.exit(1) + try: + ResetSupersetCommand(silent, user, exclude_users, exclude_roles).run() + click.secho("Factory reset complete", fg="green") + except Exception as ex: # pylint: disable=broad-except + click.secho(f"Factory reset failed: {ex}", fg="red") + sys.exit(1) diff --git a/superset/cli/test.py b/superset/cli/test.py index e8cc0925e03..d48a772a377 100755 --- a/superset/cli/test.py +++ b/superset/cli/test.py @@ -18,10 +18,11 @@ import logging import click from colorama import Fore +from flask import current_app from flask.cli import with_appcontext import superset.utils.database as database_utils 
-from superset import app, security_manager +from superset import security_manager from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -38,7 +39,7 @@ def load_test_users() -> None: """ print(Fore.GREEN + "Loading a set of users for unit tests") - if app.config["TESTING"]: + if current_app.config["TESTING"]: sm = security_manager examples_db = database_utils.get_example_database() diff --git a/superset/commands/dashboard/filter_state/get.py b/superset/commands/dashboard/filter_state/get.py index 29104b5ee2d..1fee32e1b01 100644 --- a/superset/commands/dashboard/filter_state/get.py +++ b/superset/commands/dashboard/filter_state/get.py @@ -28,8 +28,9 @@ from superset.temporary_cache.utils import cache_key class GetFilterStateCommand(GetTemporaryCacheCommand): def __init__(self, cmd_params: CommandParameters) -> None: super().__init__(cmd_params) - config = app.config["FILTER_STATE_CACHE_CONFIG"] - self._refresh_timeout = config.get("REFRESH_TIMEOUT_ON_RETRIEVAL") + self._refresh_timeout = app.config["FILTER_STATE_CACHE_CONFIG"].get( + "REFRESH_TIMEOUT_ON_RETRIEVAL" + ) def get(self, cmd_params: CommandParameters) -> Optional[str]: resource_id = cmd_params.resource_id diff --git a/superset/commands/dashboard/update.py b/superset/commands/dashboard/update.py index 0f67bc5f0ef..79710ca8df7 100644 --- a/superset/commands/dashboard/update.py +++ b/superset/commands/dashboard/update.py @@ -19,10 +19,11 @@ import textwrap from functools import partial from typing import Any, Optional +from flask import current_app from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError -from superset import app, db, security_manager +from superset import db, security_manager from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.dashboard.exceptions import ( DashboardColorsConfigUpdateFailedError, @@ -175,7 +176,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): to=email, 
subject=f"[Report: {report.name}] Deactivated", html_content=html_content, - config=app.config, + config=current_app.config, ) def deactivate_reports(reports_list: list[ReportSchedule]) -> None: diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 9425dc5ea94..90c4e3bd323 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -18,7 +18,7 @@ import logging from functools import partial from typing import Any, Optional -from flask import current_app +from flask import current_app as app from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -48,7 +48,7 @@ from superset.models.core import Database from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) -stats_logger = current_app.config["STATS_LOGGER"] +stats_logger = app.config["STATS_LOGGER"] class CreateDatabaseCommand(BaseCommand): diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py index e06e9a0be75..55b903744a6 100644 --- a/superset/commands/database/importers/v1/utils.py +++ b/superset/commands/database/importers/v1/utils.py @@ -18,7 +18,9 @@ import logging from typing import Any -from superset import app, db, security_manager +from flask import current_app as app + +from superset import db, security_manager from superset.commands.database.utils import add_permissions from superset.commands.exceptions import ImportFailedError from superset.databases.ssh_tunnel.models import SSHTunnel diff --git a/superset/commands/database/sync_permissions.py b/superset/commands/database/sync_permissions.py index bceacff89b1..a61d3ad80d2 100644 --- a/superset/commands/database/sync_permissions.py +++ b/superset/commands/database/sync_permissions.py @@ -20,9 +20,9 @@ import logging from functools import partial from typing import Iterable -from flask import current_app, g +from flask import current_app as app, g 
-from superset import app, security_manager +from superset import security_manager from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( DatabaseConnectionFailedError, @@ -320,7 +320,7 @@ def sync_database_permissions_task( """ Celery task that triggers the SyncPermissionsCommand in async mode. """ - with current_app.test_request_context(): + with app.test_request_context(): try: user = security_manager.get_user_by_username(username) if not user: diff --git a/superset/commands/database/validate_sql.py b/superset/commands/database/validate_sql.py index c13d4d5d338..c252cf6fdb8 100644 --- a/superset/commands/database/validate_sql.py +++ b/superset/commands/database/validate_sql.py @@ -18,7 +18,7 @@ import logging import re from typing import Any, Optional -from flask import current_app +from flask import current_app as app from flask_babel import gettext as __ from superset.commands.base import BaseCommand @@ -63,7 +63,7 @@ class ValidateSQLCommand(BaseCommand): catalog = self._properties.get("catalog") schema = self._properties.get("schema") try: - timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"] + timeout = app.config["SQLLAB_VALIDATION_TIMEOUT"] timeout_msg = f"The query exceeded the {timeout} seconds timeout." 
with utils.timeout(seconds=timeout, error_message=timeout_msg): errors = self._validator.validate(sql, catalog, schema, self._model) @@ -94,7 +94,7 @@ class ValidateSQLCommand(BaseCommand): raise DatabaseNotFoundError() spec = self._model.db_engine_spec - validators_by_engine = current_app.config["SQL_VALIDATORS_BY_ENGINE"] + validators_by_engine = app.config["SQL_VALIDATORS_BY_ENGINE"] if not validators_by_engine or spec.engine not in validators_by_engine: raise NoValidatorConfigFoundError( SupersetError( diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index b03f49c4e3f..f96717c0bbe 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -21,7 +21,7 @@ from typing import Any from urllib import request import pandas as pd -from flask import current_app +from flask import current_app as app from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.sql.visitors import VisitableType @@ -88,7 +88,7 @@ def validate_data_uri(data_uri: str) -> None: :param data_uri: :return: """ - allowed_urls = current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] + allowed_urls = app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] for allowed_url in allowed_urls: try: match = re.match(allowed_url, data_uri) @@ -218,7 +218,7 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None: df[column_name] = pd.to_datetime(df[column_name]) # reuse session when loading data if possible, to make import atomic - if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"): + if database.sqlalchemy_uri == app.config.get("SQLALCHEMY_DATABASE_URI"): logger.info("Loading data inside the import transaction") connection = db.session.connection() df.to_sql( diff --git a/superset/commands/distributed_lock/base.py b/superset/commands/distributed_lock/base.py index 
8fdb402a221..03fb9f2a6b4 100644 --- a/superset/commands/distributed_lock/base.py +++ b/superset/commands/distributed_lock/base.py @@ -19,14 +19,14 @@ import logging import uuid from typing import Any, Union -from flask import current_app +from flask import current_app as app from superset.commands.base import BaseCommand from superset.distributed_lock.utils import get_key from superset.key_value.types import JsonKeyValueCodec, KeyValueResource logger = logging.getLogger(__name__) -stats_logger = current_app.config["STATS_LOGGER"] +stats_logger = app.config["STATS_LOGGER"] class BaseDistributedLockCommand(BaseCommand): diff --git a/superset/commands/distributed_lock/create.py b/superset/commands/distributed_lock/create.py index c654089336a..2ac443df57f 100644 --- a/superset/commands/distributed_lock/create.py +++ b/superset/commands/distributed_lock/create.py @@ -19,7 +19,7 @@ import logging from datetime import datetime, timedelta from functools import partial -from flask import current_app +from flask import current_app as app from sqlalchemy.exc import SQLAlchemyError from superset.commands.distributed_lock.base import BaseDistributedLockCommand @@ -33,7 +33,7 @@ from superset.key_value.types import KeyValueResource from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) -stats_logger = current_app.config["STATS_LOGGER"] +stats_logger = app.config["STATS_LOGGER"] class CreateDistributedLock(BaseDistributedLockCommand): diff --git a/superset/commands/distributed_lock/delete.py b/superset/commands/distributed_lock/delete.py index cd279dbe240..2f4b6490100 100644 --- a/superset/commands/distributed_lock/delete.py +++ b/superset/commands/distributed_lock/delete.py @@ -18,7 +18,7 @@ import logging from functools import partial -from flask import current_app +from flask import current_app as app from sqlalchemy.exc import SQLAlchemyError from superset.commands.distributed_lock.base import BaseDistributedLockCommand @@ -28,7 +28,7 
@@ from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) -stats_logger = current_app.config["STATS_LOGGER"] +stats_logger = app.config["STATS_LOGGER"] class DeleteDistributedLock(BaseDistributedLockCommand): diff --git a/superset/commands/distributed_lock/get.py b/superset/commands/distributed_lock/get.py index 56245641093..30115f54226 100644 --- a/superset/commands/distributed_lock/get.py +++ b/superset/commands/distributed_lock/get.py @@ -20,14 +20,14 @@ from __future__ import annotations import logging from typing import cast -from flask import current_app +from flask import current_app as app from superset.commands.distributed_lock.base import BaseDistributedLockCommand from superset.daos.key_value import KeyValueDAO from superset.distributed_lock.types import LockValue logger = logging.getLogger(__name__) -stats_logger = current_app.config["STATS_LOGGER"] +stats_logger = app.config["STATS_LOGGER"] class GetDistributedLock(BaseDistributedLockCommand): diff --git a/superset/commands/explore/form_data/get.py b/superset/commands/explore/form_data/get.py index 0153888d4e3..003c1dd74b4 100644 --- a/superset/commands/explore/form_data/get.py +++ b/superset/commands/explore/form_data/get.py @@ -35,8 +35,9 @@ logger = logging.getLogger(__name__) class GetFormDataCommand(BaseCommand, ABC): def __init__(self, cmd_params: CommandParameters) -> None: self._cmd_params = cmd_params - config = app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] - self._refresh_timeout = config.get("REFRESH_TIMEOUT_ON_RETRIEVAL") + self._refresh_timeout = app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"].get( + "REFRESH_TIMEOUT_ON_RETRIEVAL" + ) def run(self) -> Optional[str]: try: diff --git a/superset/commands/report/alert.py b/superset/commands/report/alert.py index 458f78fd3c6..bd8d97d4491 100644 --- a/superset/commands/report/alert.py +++ b/superset/commands/report/alert.py @@ -25,9 +25,10 
@@ from uuid import UUID import numpy as np import pandas as pd from celery.exceptions import SoftTimeLimitExceeded +from flask import current_app as app from flask_babel import lazy_gettext as _ -from superset import app, jinja_context, security_manager +from superset import jinja_context, security_manager from superset.commands.base import BaseCommand from superset.commands.report.exceptions import ( AlertQueryError, diff --git a/superset/commands/report/base.py b/superset/commands/report/base.py index 199f985d0d0..18ea6ad1fd9 100644 --- a/superset/commands/report/base.py +++ b/superset/commands/report/base.py @@ -18,7 +18,7 @@ import logging from typing import Any from croniter import croniter -from flask import current_app +from flask import current_app as app from marshmallow import ValidationError from superset.commands.base import BaseCommand @@ -97,7 +97,7 @@ class BaseReportScheduleCommand(BaseCommand): if report_type == ReportScheduleType.ALERT else "REPORT_MINIMUM_INTERVAL" ) - minimum_interval = current_app.config.get(config_key, 0) + minimum_interval = app.config.get(config_key, 0) if callable(minimum_interval): minimum_interval = minimum_interval() diff --git a/superset/commands/report/execute.py b/superset/commands/report/execute.py index 2e94b8c4b80..1b19738bab9 100644 --- a/superset/commands/report/execute.py +++ b/superset/commands/report/execute.py @@ -21,8 +21,9 @@ from uuid import UUID import pandas as pd from celery.exceptions import SoftTimeLimitExceeded +from flask import current_app as app -from superset import app, db, security_manager +from superset import db, security_manager from superset.commands.base import BaseCommand from superset.commands.dashboard.permalink.create import CreateDashboardPermalinkCommand from superset.commands.exceptions import CommandException, UpdateFailedError @@ -310,6 +311,7 @@ class BaseReportState: Get chart or dashboard screenshots :raises: ReportScheduleScreenshotFailedError """ + _, username = 
get_executor( executors=app.config["ALERT_REPORTS_EXECUTORS"], model=self._report_schedule, @@ -407,6 +409,7 @@ class BaseReportState: """ Return data as a Pandas dataframe, to embed in notifications as a table. """ + url = self._get_url(result_format=ChartDataResultFormat.JSON) _, username = get_executor( executors=app.config["ALERT_REPORTS_EXECUTORS"], @@ -880,6 +883,7 @@ class AsyncExecuteReportScheduleCommand(BaseCommand): self.validate() if not self._model: raise ReportScheduleExecuteUnexpectedError() + _, username = get_executor( executors=app.config["ALERT_REPORTS_EXECUTORS"], model=self._model, diff --git a/superset/commands/sql_lab/estimate.py b/superset/commands/sql_lab/estimate.py index d3198815662..03aec8ade21 100644 --- a/superset/commands/sql_lab/estimate.py +++ b/superset/commands/sql_lab/estimate.py @@ -19,9 +19,10 @@ from __future__ import annotations import logging from typing import Any, TypedDict +from flask import current_app as app from flask_babel import gettext as __ -from superset import app, db +from superset import db from superset.commands.base import BaseCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException, SupersetTimeoutException @@ -29,10 +30,6 @@ from superset.jinja_context import get_template_processor from superset.models.core import Database from superset.utils import core as utils -config = app.config -SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"] -stats_logger = config["STATS_LOGGER"] - logger = logging.getLogger(__name__) @@ -81,7 +78,7 @@ class QueryEstimationCommand(BaseCommand): template_processor = get_template_processor(self._database) sql = template_processor.process_template(sql, **self._template_params) - timeout = SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT + timeout = app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"] timeout_msg = f"The estimation exceeded the {timeout} seconds timeout." 
try: with utils.timeout(seconds=timeout, error_message=timeout_msg): @@ -100,7 +97,7 @@ class QueryEstimationCommand(BaseCommand): "The query estimation was killed after %(sqllab_timeout)s " "seconds. It might be too complex, or the database might be " "under heavy load.", - sqllab_timeout=SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT, + sqllab_timeout=app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"], ), error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR, level=ErrorLevel.ERROR, diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index 34c4cc4d7a9..e488bf0a3ae 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -20,9 +20,10 @@ import logging from typing import Any, cast, TypedDict import pandas as pd +from flask import current_app as app from flask_babel import gettext as __ -from superset import app, db, results_backend, results_backend_use_msgpack +from superset import db, results_backend, results_backend_use_msgpack from superset.commands.base import BaseCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException, SupersetSecurityException @@ -32,8 +33,6 @@ from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils, csv from superset.views.utils import _deserialize_results_payload -config = app.config - logger = logging.getLogger(__name__) @@ -132,8 +131,8 @@ class SqlResultExportCommand(BaseCommand): )[:limit] # Manual encoding using the specified encoding (default to utf-8 if not set) - csv_string = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"]) - csv_data = csv_string.encode(config["CSV_EXPORT"].get("encoding", "utf-8")) + csv_string = csv.df_to_escaped_csv(df, index=False, **app.config["CSV_EXPORT"]) + csv_data = csv_string.encode(app.config["CSV_EXPORT"].get("encoding", "utf-8")) return { "query": self._query, diff --git a/superset/commands/sql_lab/results.py 
b/superset/commands/sql_lab/results.py index 83c8aa8f6a5..cf89c3d237d 100644 --- a/superset/commands/sql_lab/results.py +++ b/superset/commands/sql_lab/results.py @@ -19,9 +19,10 @@ from __future__ import annotations import logging from typing import Any, cast +from flask import current_app as app from flask_babel import gettext as __ -from superset import app, db, results_backend, results_backend_use_msgpack +from superset import db, results_backend, results_backend_use_msgpack from superset.commands.base import BaseCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SerializationError, SupersetErrorException @@ -31,10 +32,6 @@ from superset.utils import core as utils from superset.utils.dates import now_as_float from superset.views.utils import _deserialize_results_payload -config = app.config -SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"] -stats_logger = config["STATS_LOGGER"] - logger = logging.getLogger(__name__) @@ -64,7 +61,7 @@ class SqlExecutionResultsCommand(BaseCommand): read_from_results_backend_start = now_as_float() self._blob = results_backend.get(self._key) - stats_logger.timing( + app.config["STATS_LOGGER"].timing( "sqllab.query.results_backend_read", now_as_float() - read_from_results_backend_start, ) diff --git a/superset/commands/theme/seed.py b/superset/commands/theme/seed.py index 056fc78cca6..3aeeb0e4142 100644 --- a/superset/commands/theme/seed.py +++ b/superset/commands/theme/seed.py @@ -17,7 +17,7 @@ import logging from typing import Any -from flask import current_app +from flask import current_app as app from superset.commands.base import BaseCommand from superset.daos.theme import ThemeDAO @@ -36,9 +36,9 @@ class SeedSystemThemesCommand(BaseCommand): """Seed system themes defined in application configuration.""" themes_to_seed = [] - if theme_default := current_app.config.get("THEME_DEFAULT"): + if theme_default := 
app.config.get("THEME_DEFAULT"): themes_to_seed.append(("THEME_DEFAULT", theme_default)) - if theme_dark := current_app.config.get("THEME_DARK"): + if theme_dark := app.config.get("THEME_DARK"): themes_to_seed.append(("THEME_DARK", theme_dark)) for theme_name, theme_config in themes_to_seed: @@ -62,7 +62,7 @@ class SeedSystemThemesCommand(BaseCommand): f"Copied at startup from theme UUID {original_uuid} " f"based on config reference" ) - logger.info( + logger.debug( "Copied theme definition from UUID %s for system theme %s", original_uuid, theme_name, @@ -93,7 +93,7 @@ class SeedSystemThemesCommand(BaseCommand): if existing_theme: existing_theme.json_data = json_data - logger.info(f"Updated system theme: {theme_name}") + logger.debug(f"Updated system theme: {theme_name}") else: new_theme = Theme( theme_name=theme_name, @@ -101,7 +101,7 @@ class SeedSystemThemesCommand(BaseCommand): is_system=True, ) db.session.add(new_theme) - logger.info(f"Created system theme: {theme_name}") + logger.debug(f"Created system theme: {theme_name}") def validate(self) -> None: """Validate that the command can be executed.""" diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 9e61de6e1aa..694973a47ac 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -21,7 +21,6 @@ from typing import Any, Callable, TYPE_CHECKING from flask_babel import _ -from superset import app from superset.common.chart_data import ChartDataResultType from superset.common.db_query_status import QueryStatus from superset.connectors.sqla.models import BaseDatasource @@ -38,8 +37,6 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject -config = app.config - def _get_datasource( query_context: QueryContext, query_obj: QueryObject diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 7a9e9e15028..9205ba418ef 100644 --- 
a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -18,7 +18,8 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING -from superset import app +from flask import current_app + from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject @@ -31,11 +32,9 @@ from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource -config = app.config - def create_query_object_factory() -> QueryObjectFactory: - return QueryObjectFactory(config, DatasourceDAO()) + return QueryObjectFactory(current_app.config, DatasourceDAO()) class QueryContextFactory: # pylint: disable=too-few-public-methods diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 49f8b35466a..209fae31ead 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -24,10 +24,10 @@ from typing import Any, cast, ClassVar, TYPE_CHECKING, TypedDict import numpy as np import pandas as pd +from flask import current_app from flask_babel import gettext as _ from pandas import DateOffset -from superset import app from superset.common.chart_data import ChartDataResultFormat from superset.common.db_query_status import QueryStatus from superset.common.query_actions import get_query_results @@ -77,10 +77,7 @@ from superset.viz import viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject - from superset.stats_logger import BaseStatsLogger -config = app.config -stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) # Offset join column suffix used for joining offset results @@ -599,7 +596,7 @@ class QueryContextProcessor: # to the 
subquery so we prevent data inconsistency due to missing records # in the dataframes when performing the join if query_object.row_limit or query_object.row_offset: - query_object_clone_dct["row_limit"] = config["ROW_LIMIT"] + query_object_clone_dct["row_limit"] = current_app.config["ROW_LIMIT"] query_object_clone_dct["row_offset"] = 0 if isinstance(self._qc_datasource, Query): @@ -666,9 +663,9 @@ class QueryContextProcessor: :param time_grain: The time grain used to calculate the temporal join key. :param join_keys: The keys to join on. """ - join_column_producer = config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get( - time_grain - ) + join_column_producer = current_app.config[ + "TIME_GRAIN_JOIN_COLUMN_PRODUCERS" + ].get(time_grain) if join_column_producer and not time_grain: raise QueryObjectValidationError( @@ -775,11 +772,11 @@ class QueryContextProcessor: result = None if self._query_context.result_format == ChartDataResultFormat.CSV: result = csv.df_to_escaped_csv( - df, index=include_index, **config["CSV_EXPORT"] + df, index=include_index, **current_app.config["CSV_EXPORT"] ) elif self._query_context.result_format == ChartDataResultFormat.XLSX: excel.apply_column_types(df, coltypes) - result = excel.df_to_excel(df, **config["EXCEL_EXPORT"]) + result = excel.df_to_excel(df, **current_app.config["EXCEL_EXPORT"]) return result or "" return df.to_dict(orient="records") @@ -870,12 +867,12 @@ class QueryContextProcessor: if cache_timeout_rv := self._query_context.get_cache_timeout(): return cache_timeout_rv if ( - data_cache_timeout := config["DATA_CACHE_CONFIG"].get( + data_cache_timeout := current_app.config["DATA_CACHE_CONFIG"].get( "CACHE_DEFAULT_TIMEOUT" ) ) is not None: return data_cache_timeout - return config["CACHE_DEFAULT_TIMEOUT"] + return current_app.config["CACHE_DEFAULT_TIMEOUT"] def cache_key(self, **extra: Any) -> str: """ diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 
8a488371119..8b101499255 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -19,10 +19,10 @@ from __future__ import annotations import logging from typing import Any +from flask import current_app from flask_caching import Cache from pandas import DataFrame -from superset import app from superset.common.db_query_status import QueryStatus from superset.constants import CacheRegion from superset.exceptions import CacheLoadError @@ -33,8 +33,6 @@ from superset.superset_typing import Column from superset.utils.cache import set_and_log_cache from superset.utils.core import error_msg_from_exception, get_stacktrace -config = app.config -stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) _cache: dict[CacheRegion, Cache] = { @@ -48,6 +46,10 @@ class QueryCacheManager: Class for manage query-cache getting and setting """ + @property + def stats_logger(self) -> BaseStatsLogger: + return current_app.config["STATS_LOGGER"] + # pylint: disable=too-many-instance-attributes,too-many-arguments def __init__( self, @@ -108,9 +110,11 @@ class QueryCacheManager: self.annotation_data = {} if annotation_data is None else annotation_data if self.status != QueryStatus.FAILED: - stats_logger.incr("loaded_from_source") + current_app.config["STATS_LOGGER"].incr("loaded_from_source") if not force_query: - stats_logger.incr("loaded_from_source_without_force") + current_app.config["STATS_LOGGER"].incr( + "loaded_from_source_without_force" + ) self.is_loaded = True value = { @@ -154,7 +158,7 @@ class QueryCacheManager: if cache_value := _cache[region].get(key): logger.debug("Cache key: %s", key) - stats_logger.incr("loading_from_cache") + current_app.config["STATS_LOGGER"].incr("loading_from_cache") try: query_cache.df = cache_value["df"] query_cache.query = cache_value["query"] @@ -176,7 +180,7 @@ class QueryCacheManager: cache_value["dttm"] if cache_value is not None else None ) 
query_cache.cache_value = cache_value - stats_logger.incr("loaded_from_cache") + current_app.config["STATS_LOGGER"].incr("loaded_from_cache") except KeyError as ex: logger.exception(ex) logger.error( diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py index 48f692b9f2e..ca89cc47393 100644 --- a/superset/common/utils/time_range_utils.py +++ b/superset/common/utils/time_range_utils.py @@ -19,7 +19,8 @@ from __future__ import annotations from datetime import datetime from typing import Any, cast -from superset import app +from flask import current_app + from superset.common.query_object import QueryObject from superset.utils.core import FilterOperator from superset.utils.date_parser import get_since_until @@ -32,10 +33,10 @@ def get_since_until_from_time_range( ) -> tuple[datetime | None, datetime | None]: return get_since_until( relative_start=(extras or {}).get( - "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"] + "relative_start", current_app.config["DEFAULT_RELATIVE_START_TIME"] ), relative_end=(extras or {}).get( - "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"] + "relative_end", current_app.config["DEFAULT_RELATIVE_END_TIME"] ), time_range=time_range, time_shift=time_shift, diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c6f30c5a7fc..d6fe46ba5f9 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -28,6 +28,7 @@ from typing import Any, Callable, cast, Optional, Union import pandas as pd import sqlalchemy as sa +from flask import current_app from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ @@ -67,7 +68,7 @@ from sqlalchemy.sql.expression import Label from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy.types import JSON -from superset import app, db, is_feature_enabled, security_manager +from superset 
import db, is_feature_enabled, security_manager from superset.commands.dataset.exceptions import DatasetNotFoundError from superset.common.db_query_status import QueryStatus from superset.connectors.sqla.utils import ( @@ -111,10 +112,9 @@ from superset.superset_typing import ( from superset.utils import core as utils, json from superset.utils.backports import StrEnum -config = app.config +config = current_app.config # Backward compatibility for tests metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] VIRTUAL_TABLE_ALIAS = "virtual_table" # a non-exhaustive set of additive metrics @@ -1310,7 +1310,7 @@ class SqlaTable( @property def health_check_message(self) -> str | None: - check = config["DATASET_HEALTH_CHECK"] + check = current_app.config["DATASET_HEALTH_CHECK"] return check(self) if check else None @property @@ -1721,7 +1721,7 @@ class SqlaTable( self.add_missing_metrics(metrics) # Apply config supplied mutations. 
- config["SQLA_TABLE_MUTATOR"](self) + current_app.config["SQLA_TABLE_MUTATOR"](self) db.session.merge(self) return results diff --git a/superset/databases/api.py b/superset/databases/api.py index c9b882dd6b9..c94060b2838 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -25,13 +25,20 @@ from typing import Any, cast from zipfile import is_zipfile, ZipFile from deprecation import deprecated -from flask import make_response, render_template, request, Response, send_file +from flask import ( + current_app as app, + make_response, + render_template, + request, + Response, + send_file, +) from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from marshmallow import ValidationError from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError -from superset import app, event_logger +from superset import event_logger from superset.commands.database.create import CreateDatabaseCommand from superset.commands.database.delete import DeleteDatabaseCommand from superset.commands.database.exceptions import ( diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 420a55fd263..321eb621005 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -23,7 +23,7 @@ from sqlalchemy.orm import Query from sqlalchemy.sql.expression import cast from sqlalchemy.sql.sqltypes import JSON -from superset import app, security_manager +from superset import security_manager from superset.models.core import Database from superset.views.base import BaseFilter @@ -91,7 +91,7 @@ class DatabaseUploadEnabledFilter(BaseFilter): # pylint: disable=too-few-public if hasattr(g, "user"): allowed_schemas = [ - app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](database, g.user) + current_app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](database, g.user) for database in datasource_access_databases ] diff --git a/superset/databases/ssh_tunnel/models.py 
b/superset/databases/ssh_tunnel/models.py index 5c1450cec09..4a0e624b521 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -18,7 +18,6 @@ from typing import Any import sqlalchemy as sa -from flask import current_app from flask_appbuilder import Model from sqlalchemy.orm import backref, relationship from sqlalchemy.types import Text @@ -32,8 +31,6 @@ from superset.models.helpers import ( ImportExportMixin, ) -app_config = current_app.config - class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model): """ diff --git a/superset/datasource/api.py b/superset/datasource/api.py index 31e8c503ee0..077f3c97314 100644 --- a/superset/datasource/api.py +++ b/superset/datasource/api.py @@ -16,9 +16,10 @@ # under the License. import logging +from flask import current_app as app from flask_appbuilder.api import expose, protect, safe -from superset import app, event_logger +from superset import event_logger from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.exceptions import SupersetSecurityException diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index a68e8847bec..40bb985f404 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -38,10 +38,11 @@ from pathlib import Path from typing import Any, Optional import sqlalchemy.dialects +from flask import current_app as app from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.exc import NoSuchModuleError -from superset import app, feature_flag_manager +from superset import feature_flag_manager from superset.db_engine_specs.base import BaseEngineSpec logger = logging.getLogger(__name__) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 069f5299a92..cccecff6539 100644 --- a/superset/db_engine_specs/base.py +++ 
b/superset/db_engine_specs/base.py @@ -42,7 +42,7 @@ import requests from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from deprecation import deprecated -from flask import current_app, g, url_for +from flask import current_app as app, g, url_for from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema @@ -459,7 +459,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def is_oauth2_enabled(cls) -> bool: return ( cls.supports_oauth2 - and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"] + and cls.engine_name in app.config["DATABASE_OAUTH2_CLIENTS"] ) @classmethod @@ -512,12 +512,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Build the DB engine spec level OAuth2 client config. """ - oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"] + oauth2_config = app.config["DATABASE_OAUTH2_CLIENTS"] if cls.engine_name not in oauth2_config: return None db_engine_spec_config = oauth2_config[cls.engine_name] - redirect_uri = current_app.config.get( + redirect_uri = app.config.get( "DATABASE_OAUTH2_REDIRECT_URI", url_for("DatabaseRestApi.oauth2", _external=True), ) @@ -573,7 +573,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Exchange authorization code for refresh/access tokens. """ - timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] req_body = { "code": code, @@ -595,7 +595,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Refresh an access token that has expired. 
""" - timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] req_body = { "client_id": config["id"], @@ -871,7 +871,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ret_list = [] time_grains = builtin_time_grains.copy() - time_grains.update(current_app.config["TIME_GRAIN_ADDONS"]) + time_grains.update(app.config["TIME_GRAIN_ADDONS"]) for duration, func in cls.get_time_grain_expressions().items(): if duration in time_grains: name = time_grains[duration] @@ -950,9 +950,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ # TODO: use @memoize decorator or similar to avoid recomputation on every call time_grain_expressions = cls._time_grain_expressions.copy() - grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] + grain_addon_expressions = app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {})) - denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"] + denylist: list[str] = app.config["TIME_GRAIN_DENYLIST"] for key in denylist: time_grain_expressions.pop(key, None) @@ -2235,7 +2235,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param sqlalchemy_uri: """ - if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]: + if db_engine_uri_validator := app.config["DB_SQLA_URI_VALIDATOR"]: db_engine_uri_validator(sqlalchemy_uri) if existing_disallowed := cls.disallow_uri_query_params.get( diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index 91e7ccf4f51..ada2f9f5e40 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -22,7 +22,7 @@ from datetime import datetime from typing import Any, cast, TYPE_CHECKING from urllib import parse -from flask import current_app +from flask import current_app as app 
from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.validate import Range @@ -247,7 +247,7 @@ try: ) set_setting( "product_name", - f"superset/{current_app.config.get('VERSION_STRING', 'dev')}", + f"superset/{app.config.get('VERSION_STRING', 'dev')}", ) except ImportError: # ClickHouse Connect not installed, do nothing pass diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index c3b9db1b900..5cf11febca5 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -24,13 +24,13 @@ from typing import Any, TYPE_CHECKING, TypedDict from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin +from flask import current_app as app from flask_babel import gettext as __ from marshmallow import fields, Schema from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL -from superset.config import VERSION_STRING from superset.constants import TimeGrain from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec @@ -252,7 +252,8 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec): delim = " " if custom_user_agent else "" user_agent = get_user_agent(database, source) user_agent = user_agent.replace(" ", "-").lower() - user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}" + version_string = app.config["VERSION_STRING"] + user_agent = f"{user_agent}/{version_string}{delim}{custom_user_agent}" config.setdefault("custom_user_agent", user_agent) return extra diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 34ac2aa0ac8..e33438338d2 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -29,7 +29,7 @@ import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -from flask import current_app, g +from flask import current_app 
as app, g from sqlalchemy import Column, text, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector @@ -66,7 +66,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: import boto3 # pylint: disable=all from boto3.s3.transfer import TransferConfig # pylint: disable=all - bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] + bucket_path = app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] if not bucket_path: logger.info("No upload bucket specified") @@ -224,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec): ) with tempfile.NamedTemporaryFile( - dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet" + dir=app.config["UPLOAD_FOLDER"], suffix=".parquet" ) as file: pq.write_table(pa.Table.from_pandas(df), where=file.name) @@ -243,9 +243,9 @@ class HiveEngineSpec(PrestoEngineSpec): ), location=upload_to_s3( filename=file.name, - upload_prefix=current_app.config[ - "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC" - ](database, g.user, table.schema), + upload_prefix=app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]( + database, g.user, table.schema + ), table=table, ), ) @@ -401,12 +401,13 @@ class HiveEngineSpec(PrestoEngineSpec): last_log_line = len(log_lines) if needs_commit: db.session.commit() # pylint: disable=consider-using-transaction - if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"): + if sleep_interval := app.config.get("HIVE_POLL_INTERVAL"): logger.warning( - "HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead" # noqa: E501 + "HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. 
" + "Please use DB_POLL_INTERVAL_SECONDS instead" ) else: - sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get( + sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get( cls.engine, 5 ) time.sleep(sleep_interval) diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 456f81c790b..4438f38e8c3 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -24,7 +24,7 @@ from datetime import datetime from typing import Any, Optional, TYPE_CHECKING import requests -from flask import current_app +from flask import current_app as app from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector @@ -155,7 +155,7 @@ class ImpalaEngineSpec(BaseEngineSpec): if needs_commit: db.session.commit() # pylint: disable=consider-using-transaction - sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get( + sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get( cls.engine, 5 ) time.sleep(sleep_interval) diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index 75889c706c2..44c30006299 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -28,10 +28,9 @@ with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be inst # Ensure pyocient inherits Superset's logging level import geojson import pyocient + from flask import current_app as app from shapely import wkt - from superset import app - superset_log_level = app.config["LOG_LEVEL"] pyocient.logger.setLevel(superset_log_level) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index ec0ddc2df02..92786aba7b7 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -30,7 +30,7 @@ from typing import Any, cast, Optional, TYPE_CHECKING from urllib import parse import pandas as pd -from flask import current_app +from flask import current_app as app from flask_babel 
import gettext as __, lazy_gettext as _ from packaging.version import Version from sqlalchemy import Column, literal_column, types @@ -1318,7 +1318,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): query_id = query.id poll_interval = query.database.connect_args.get( - "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"] + "poll_interval", app.config["PRESTO_POLL_INTERVAL"] ) logger.info("Query %i: Polling the cursor for progress", query_id) polled = cursor.poll() diff --git a/superset/db_engine_specs/singlestore.py b/superset/db_engine_specs/singlestore.py index 9dba3e452a5..27df97ff9b9 100644 --- a/superset/db_engine_specs/singlestore.py +++ b/superset/db_engine_specs/singlestore.py @@ -21,7 +21,7 @@ from datetime import datetime from typing import Any, Optional from urllib import parse -from flask import current_app +from flask import current_app as app from sqlalchemy import types from sqlalchemy.engine import URL @@ -498,8 +498,8 @@ class SingleStoreSpec(BasicParametersMixin, BaseEngineSpec): "conn_attrs", { "_connector_name": "SingleStore Superset Database Engine", - "_connector_version": current_app.config.get("VERSION_STRING", "dev"), - "_product_version": current_app.config.get("VERSION_STRING", "dev"), + "_connector_version": app.config.get("VERSION_STRING", "dev"), + "_product_version": app.config.get("VERSION_STRING", "dev"), }, ) return uri, connect_args diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 5b064c22c77..5b9df25f111 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -27,7 +27,7 @@ from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from flask import current_app +from flask import current_app as app from flask_babel import gettext as __ from marshmallow import fields, Schema from sqlalchemy import 
types @@ -411,9 +411,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): ) connect_args["private_key"] = pkb else: - allowed_extra_auths = current_app.config[ - "ALLOWED_EXTRA_AUTHENTICATIONS" - ].get("snowflake", {}) + allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get( + "snowflake", {} + ) if auth_method in allowed_extra_auths: snowflake_auth = allowed_extra_auths.get(auth_method) else: diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 7f5debeed74..e5d0789ff2c 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -23,7 +23,7 @@ import time from typing import Any, TYPE_CHECKING import requests -from flask import copy_current_request_context, ctx, current_app, Flask, g +from flask import copy_current_request_context, ctx, current_app as app, Flask, g from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError @@ -249,7 +249,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): args=( execute_result, execute_event, - current_app._get_current_object(), # pylint: disable=protected-access + app._get_current_object(), # pylint: disable=protected-access g._get_current_object(), # pylint: disable=protected-access ), ) @@ -352,9 +352,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): elif auth_method == "jwt": from trino.auth import JWTAuthentication as trino_auth # noqa else: - allowed_extra_auths = current_app.config[ - "ALLOWED_EXTRA_AUTHENTICATIONS" - ].get("trino", {}) + allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get( + "trino", {} + ) if auth_method in allowed_extra_auths: trino_auth = allowed_extra_auths.get(auth_method) else: diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index ddd5ea50aa2..79210493f71 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -19,10 +19,11 @@ import textwrap from typing import Union 
import pandas as pd +from flask import current_app from sqlalchemy import DateTime, inspect, String from sqlalchemy.sql import column -from superset import app, db, security_manager +from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -173,7 +174,7 @@ def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]: "limit": "25", "granularity_sqla": "ds", "groupby": [], - "row_limit": app.config["ROW_LIMIT"], + "row_limit": current_app.config["ROW_LIMIT"], "time_range": "100 years ago : now", "viz_type": "table", "markup_type": "markdown", @@ -584,8 +585,7 @@ def create_dashboard(slices: list[Slice]) -> Dashboard: } }""" ) - # pylint: disable=line-too-long - pos = json.loads( # noqa: TID251 + pos = json.loads( textwrap.dedent( """\ { @@ -859,7 +859,6 @@ def create_dashboard(slices: list[Slice]) -> Dashboard: """ # noqa: E501 ) ) - # pylint: enable=line-too-long # dashboard v2 doesn't allow add markup slice dash.slices = [slc for slc in slices if slc.viz_type != "markup"] update_slice_ids(pos) diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py index 3be6589979e..c10cd0f36bd 100644 --- a/superset/examples/helpers.py +++ b/superset/examples/helpers.py @@ -48,8 +48,9 @@ from typing import Any from urllib.error import HTTPError import pandas as pd +from flask import current_app -from superset import app, db +from superset import db from superset.connectors.sqla.models import SqlaTable from superset.models.slice import Slice from superset.utils import json @@ -80,7 +81,7 @@ def get_table_connector_registry() -> Any: def get_examples_folder() -> str: """Return local path to the examples folder (when vendored).""" - return os.path.join(app.config["BASE_DIR"], "examples") + return os.path.join(current_app.config["BASE_DIR"], "examples") def update_slice_ids(pos: dict[Any, Any]) -> 
list[Slice]: diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index a86ecdbdc54..3e6130903f7 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -18,9 +18,10 @@ import logging from typing import Optional import pandas as pd +from flask import current_app from sqlalchemy import BigInteger, Date, DateTime, inspect, String -from superset import app, db +from superset import db from superset.models.slice import Slice from superset.sql.parse import Table from superset.utils.core import DatasourceType @@ -115,7 +116,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals slice_data = { "metrics": ["count"], "granularity_sqla": col.column_name, - "row_limit": app.config["ROW_LIMIT"], + "row_limit": current_app.config["ROW_LIMIT"], "since": "2015", "until": "2016", "viz_type": "cal_heatmap", diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index c92cd617618..395eb99c2ce 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -17,10 +17,11 @@ import logging import pandas as pd +from flask import current_app from sqlalchemy import DateTime, inspect, String import superset.utils.database as database_utils -from superset import app, db +from superset import db from superset.models.slice import Slice from superset.sql.parse import Table from superset.utils.core import DatasourceType @@ -81,7 +82,7 @@ def load_random_time_series_data( slice_data = { "granularity_sqla": "ds", - "row_limit": app.config["ROW_LIMIT"], + "row_limit": current_app.config["ROW_LIMIT"], "since": "2019-01-01", "until": "2019-02-01", "metrics": ["count"], diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index e1e9363fbda..57413ca28ea 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -18,11 +18,12 @@ import logging import 
os import pandas as pd +from flask import current_app from sqlalchemy import DateTime, inspect, String from sqlalchemy.sql import column import superset.utils.database -from superset import app, db +from superset import db from superset.connectors.sqla.models import BaseDatasource, SqlMetric from superset.examples.helpers import ( get_examples_folder, @@ -155,7 +156,7 @@ def create_slices(tbl: BaseDatasource) -> list[Slice]: "limit": "25", "granularity_sqla": "year", "groupby": [], - "row_limit": app.config["ROW_LIMIT"], + "row_limit": current_app.config["ROW_LIMIT"], "since": "2014-01-01", "until": "2014-01-02", "time_range": "2014-01-01 : 2014-01-02", diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 83085e7d131..9ac70a4ad1d 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -23,6 +23,7 @@ import sys from typing import Any, Callable, TYPE_CHECKING import wtforms_json +from colorama import Fore, Style from deprecation import deprecated from flask import abort, Flask, redirect, request, session, url_for from flask_appbuilder import expose, IndexView @@ -201,6 +202,12 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods from superset.views.users_list import UsersListView set_app_error_handlers(self.superset_app) + self.register_request_handlers() + + # Register health blueprint + from superset.views.health import health_blueprint + + self.superset_app.register_blueprint(health_blueprint) # # Setup API views @@ -552,6 +559,54 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods if self.config["SESSION_SERVER_SIDE"]: Session(self.superset_app) + def register_request_handlers(self) -> None: + """Register app-level request handlers""" + from flask import Response + + @self.superset_app.after_request + def apply_http_headers(response: Response) -> Response: + """Applies the configuration's http headers to all responses""" + # HTTP_HEADERS is 
deprecated, this provides backwards compatibility + response.headers.extend( + { + **self.superset_app.config["OVERRIDE_HTTP_HEADERS"], + **self.superset_app.config["HTTP_HEADERS"], + } + ) + + for k, v in self.superset_app.config["DEFAULT_HTTP_HEADERS"].items(): + if k not in response.headers: + response.headers[k] = v + return response + + @self.superset_app.context_processor + def get_common_bootstrap_data() -> dict[str, Any]: + # Import here to avoid circular imports + from superset.utils import json + from superset.views.base import common_bootstrap_payload + + def serialize_bootstrap_data() -> str: + return json.dumps( + {"common": common_bootstrap_payload()}, + default=json.pessimistic_json_iso_dttm_ser, + ) + + return {"bootstrap_data": serialize_bootstrap_data} + + def check_and_warn_database_connection(self) -> None: + """Check database connection and warn if unavailable""" + try: + with self.superset_app.app_context(): + # Simple connection test + db.engine.execute("SELECT 1") + except Exception: + db_uri = self.config.get("SQLALCHEMY_DATABASE_URI", "") + safe_uri = make_url_safe(db_uri) if db_uri else "Not configured" + print( + f"{Fore.RED}ERROR: Cannot connect to database {safe_uri}\n" + f"NOTE: Most CLI commands require a database{Style.RESET_ALL}" + ) + def init_app(self) -> None: """ Main entry point which will delegate to other methods in @@ -567,6 +622,10 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.configure_feature_flags() self.configure_db_encrypt() self.setup_db() + + # Check database connection and warn if unavailable + self.check_and_warn_database_connection() + self.configure_celery() self.enable_profiling() self.setup_event_logger() @@ -599,7 +658,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods set_isolation_level_to = "READ COMMITTED" if set_isolation_level_to: - logger.info( + logger.debug( "Setting database isolation level to %s", set_isolation_level_to, ) @@ -777,6 +836,7 
@@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods async_query_manager_factory.init_app(self.superset_app) def register_blueprints(self) -> None: + # Register custom blueprints from config for bp in self.config["BLUEPRINTS"]: try: logger.info("Registering blueprint: %s", bp.name) diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index e2a09731a12..2bbba9fdec8 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -18,11 +18,11 @@ import copy import logging from typing import Any +from flask import current_app from sqlalchemy import and_, Column, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session -from superset import conf from superset.constants import TimeGrain from superset.migrations.shared.utils import paginated_update, try_load_json from superset.utils import json @@ -97,7 +97,9 @@ class MigrateViz: def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None: """Adds a temporal filter.""" granularity_sqla = rv_data.pop("granularity_sqla", None) - time_range = rv_data.pop("time_range", None) or conf.get("DEFAULT_TIME_FILTER") + time_range = rv_data.pop("time_range", None) or current_app.config.get( + "DEFAULT_TIME_FILTER" + ) if not granularity_sqla: return diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py index c2d05093b9b..011ade44908 100644 --- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -28,7 +28,6 @@ import sqlalchemy as sa from alembic import op from sqlalchemy_utils import UUIDType -from superset import app from superset.extensions 
import encrypted_field_factory from superset.migrations.shared.utils import create_table @@ -36,8 +35,6 @@ from superset.migrations.shared.utils import create_table revision = "f3c2d8ec8595" down_revision = "4ce1d9b25135" -app_config = app.config - def upgrade(): create_table( diff --git a/superset/models/core.py b/superset/models/core.py index 0a5bb515e07..f20806457fd 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -36,7 +36,7 @@ import numpy import pandas as pd import sqlalchemy as sqla import sshtunnel -from flask import g +from flask import current_app as app, g, has_app_context from flask_appbuilder import Model from marshmallow.exceptions import ValidationError from sqlalchemy import ( @@ -61,7 +61,7 @@ from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import ColumnElement, expression, Select -from superset import app, db, db_engine_specs, is_feature_enabled +from superset import db, db_engine_specs, is_feature_enabled from superset.commands.database.exceptions import DatabaseInvalidError from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK from superset.databases.utils import make_url_safe @@ -90,10 +90,6 @@ from superset.utils.oauth2 import ( OAuth2ClientConfigSchema, ) -config = app.config -custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] -stats_logger = config["STATS_LOGGER"] -log_query = config["QUERY_LOGGER"] metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) @@ -101,8 +97,6 @@ if TYPE_CHECKING: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.sql_lab import Query -DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] - class KeyValue(Model): # pylint: disable=too-few-public-methods """Used for any type of key-value store""" @@ -395,6 +389,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def set_sqlalchemy_uri(self, uri: str) -> None: conn = 
make_url_safe(uri.strip()) + custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password @@ -464,7 +459,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ssh_context, ) - engine_context_manager = config["ENGINE_CONTEXT_MANAGER"] + engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"] with engine_context_manager(self, catalog, schema): with check_for_oauth2(self): yield self._get_sqla_engine( @@ -533,7 +528,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable self.update_params_from_encrypted_extra(engine_kwargs) - if DB_CONNECTION_MUTATOR: + if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]: # noqa: N806 source = source or get_query_source_from_request() sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR( @@ -654,8 +649,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable on the group of queries as a whole. Here the called passes the context as to whether the SQL is split or already. 
""" # noqa: E501 - sql_mutator = config["SQL_QUERY_MUTATOR"] - if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]): + sql_mutator = app.config["SQL_QUERY_MUTATOR"] + if sql_mutator and (is_split == app.config["MUTATE_AFTER_SPLIT"]): return sql_mutator( sql_, security_manager=security_manager, @@ -674,6 +669,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url + log_query = app.config["QUERY_LOGGER"] + def _log_query(sql: str) -> None: if log_query: log_query( @@ -1050,7 +1047,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable allowed_databases = literal_eval(allowed_databases) if hasattr(g, "user"): - extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( + extra_allowed_databases = app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( self, g.user ) allowed_databases += extra_allowed_databases @@ -1064,8 +1061,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable # if the URI is invalid, ignore and return a placeholder url # (so users see 500 less often) return "dialect://invalid_uri" - if custom_password_store: - conn = conn.set(password=custom_password_store(conn)) + if has_app_context(): + if custom_password_store := app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]: + conn = conn.set(password=custom_password_store(conn)) + else: + conn = conn.set(password=self.password) else: conn = conn.set(password=self.password) return str(conn) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 37b67a8eb2e..3083cf282c0 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -22,6 +22,7 @@ from collections import defaultdict, deque from typing import Any, Callable import sqlalchemy as sqla +from flask import current_app as app from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from 
flask_appbuilder.security.sqla.models import User @@ -41,7 +42,7 @@ from sqlalchemy.orm import relationship, subqueryload from sqlalchemy.orm.mapper import Mapper from sqlalchemy.sql.elements import BinaryExpression -from superset import app, db, is_feature_enabled, security_manager +from superset import db, is_feature_enabled, security_manager from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.daos.datasource import DatasourceDAO from superset.models.helpers import AuditMixinNullable, ImportExportMixin @@ -53,12 +54,11 @@ from superset.thumbnails.digest import get_dashboard_digest from superset.utils import core as utils, json metadata = Model.metadata # pylint: disable=no-member -config = app.config logger = logging.getLogger(__name__) def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard) -> None: - dashboard_id = config["DASHBOARD_TEMPLATE_ID"] + dashboard_id = app.config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: return diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 400255dd07b..494aaa5e931 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -35,7 +35,7 @@ import pandas as pd import pytz import sqlalchemy as sa import yaml -from flask import g +from flask import current_app as app, g from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.mixins import AuditMixin @@ -52,7 +52,7 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType -from superset import app, db, is_feature_enabled +from superset import db, is_feature_enabled from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus from superset.common.utils.time_range_utils import get_since_until_from_time_range @@ -96,13 +96,10 @@ if TYPE_CHECKING: from 
superset.db_engine_specs import BaseEngineSpec from superset.models.core import Database - -config = app.config logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" SERIES_LIMIT_SUBQ_ALIAS = "series_limit" -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] def validate_adhoc_subquery( @@ -1870,6 +1867,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods is_list_target=is_list_target, db_engine_spec=db_engine_spec, ) + + # Get ADVANCED_DATA_TYPES from config when needed + ADVANCED_DATA_TYPES = app.config.get("ADVANCED_DATA_TYPES", {}) # noqa: N806 + if ( col_advanced_data_type != "" and feature_flag_manager.is_feature_enabled( diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 9552d856760..d1603af44b5 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -25,7 +25,7 @@ from datetime import datetime from typing import Any, Optional, TYPE_CHECKING import sqlalchemy as sqla -from flask import current_app +from flask import current_app as app from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_babel import gettext as __ @@ -334,7 +334,7 @@ class Query( Transform tracking url at run time because the exact URL may depend on query properties such as execution and finish time. 
""" - transform = current_app.config.get("TRACKING_URL_TRANSFORMER") + transform = app.config.get("TRACKING_URL_TRANSFORMER") url = self.tracking_url_raw if url and transform: sig = inspect.signature(transform) diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py index e5bf7354767..22a0d38bdd6 100644 --- a/superset/reports/notifications/email.py +++ b/superset/reports/notifications/email.py @@ -22,10 +22,11 @@ from email.utils import make_msgid, parseaddr from typing import Any, Optional import nh3 +from flask import current_app from flask_babel import gettext as __ from pytz import timezone -from superset import app, is_feature_enabled +from superset import is_feature_enabled from superset.exceptions import SupersetErrorsException from superset.reports.models import ReportRecipientType from superset.reports.notifications.base import BaseNotification @@ -94,7 +95,7 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met @staticmethod def _get_smtp_domain() -> str: - return parseaddr(app.config["SMTP_MAIL_FROM"])[1].split("@")[1] + return parseaddr(current_app.config["SMTP_MAIL_FROM"])[1].split("@")[1] def _error_template(self, text: str) -> str: call_to_action = self._get_call_to_action() @@ -202,7 +203,7 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met def _get_subject(self) -> str: return __( "%(prefix)s %(title)s", - prefix=app.config["EMAIL_REPORTS_SUBJECT_PREFIX"], + prefix=current_app.config["EMAIL_REPORTS_SUBJECT_PREFIX"], title=self._name, ) @@ -214,7 +215,7 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met return self.now.strftime(name) def _get_call_to_action(self) -> str: - return __(app.config["EMAIL_REPORTS_CTA"]) + return __(current_app.config["EMAIL_REPORTS_CTA"]) def _get_to(self) -> str: return json.loads(self._recipient.recipient_config_json)["target"] @@ -240,7 +241,7 @@ class EmailNotification(BaseNotification): 
# pylint: disable=too-few-public-met to, subject, content.body, - app.config, + current_app.config, files=[], data=content.data, pdf=content.pdf, diff --git a/superset/security/manager.py b/superset/security/manager.py index 884152b195b..5bb2ac30323 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -95,6 +95,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + +def get_conf() -> Any: + return current_app.config + + DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$") @@ -652,7 +657,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :returns: The access URL """ - return current_app.config.get("PERMISSION_INSTRUCTIONS_LINK") + return get_conf().get("PERMISSION_INSTRUCTIONS_LINK") def get_datasource_access_error_object( # pylint: disable=invalid-name self, datasource: "BaseDatasource" @@ -713,7 +718,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :returns: The access URL """ - return current_app.config.get("PERMISSION_INSTRUCTIONS_LINK") + return get_conf().get("PERMISSION_INSTRUCTIONS_LINK") def get_user_datasources(self) -> list["BaseDatasource"]: """ @@ -1114,9 +1119,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.set_role("sql_lab", self._is_sql_lab_pvm, pvms) # Configure public role - if current_app.config["PUBLIC_ROLE_LIKE"]: + if get_conf()["PUBLIC_ROLE_LIKE"]: self.copy_role( - current_app.config["PUBLIC_ROLE_LIKE"], + get_conf()["PUBLIC_ROLE_LIKE"], self.auth_role_public, merge=True, ) @@ -2474,7 +2479,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if not user: user = g.user if user.is_anonymous: - public_role = current_app.config.get("AUTH_ROLE_PUBLIC") + public_role = get_conf().get("AUTH_ROLE_PUBLIC") return [self.get_public_role()] if public_role else [] return super().get_user_roles(user) @@ -2591,7 +2596,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
@staticmethod def _get_guest_token_jwt_audience() -> str: - audience = current_app.config["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() + audience = get_conf()["GUEST_TOKEN_JWT_AUDIENCE"] or get_url_host() if callable(audience): audience = audience() return audience @@ -2620,9 +2625,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods resources: GuestTokenResources, rls: list[GuestTokenRlsRule], ) -> bytes: - secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] - algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] - exp_seconds = current_app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] + secret = get_conf()["GUEST_TOKEN_JWT_SECRET"] + algo = get_conf()["GUEST_TOKEN_JWT_ALGO"] + exp_seconds = get_conf()["GUEST_TOKEN_JWT_EXP_SECONDS"] audience = self._get_guest_token_jwt_audience() # calculate expiration time now = self._get_current_epoch_time() @@ -2649,7 +2654,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :return: A guest user object """ raw_token = req.headers.get( - current_app.config["GUEST_TOKEN_HEADER_NAME"] + get_conf()["GUEST_TOKEN_HEADER_NAME"] ) or req.form.get("guest_token") if raw_token is None: return None @@ -2675,7 +2680,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: return self.guest_user_cls( token=token, - roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], + roles=[self.find_role(get_conf()["GUEST_ROLE_NAME"])], ) def parse_jwt_guest_token(self, raw_token: str) -> dict[str, Any]: @@ -2684,8 +2689,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param raw_token: the token gotten from the request :return: the same token that was passed in, tested but unchanged """ - secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] - algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] + secret = get_conf()["GUEST_TOKEN_JWT_SECRET"] + algo = get_conf()["GUEST_TOKEN_JWT_ALGO"] 
audience = self._get_guest_token_jwt_audience() return self.pyjwt_for_guest_token.decode( raw_token, secret, algorithms=[algo], audience=audience @@ -2780,7 +2785,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :returns: Whether the current user is an admin user """ - return current_app.config["AUTH_ROLE_ADMIN"] in [ + return get_conf()["AUTH_ROLE_ADMIN"] in [ role.name for role in self.get_user_roles() ] diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 57b6410ece0..c15e9242469 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -27,11 +27,10 @@ from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union import backoff import msgpack from celery.exceptions import SoftTimeLimitExceeded -from flask import current_app +from flask import current_app as app, has_app_context from flask_babel import gettext as __ from superset import ( - app, db, is_feature_enabled, results_backend, @@ -72,13 +71,6 @@ from superset.utils.rls import apply_rls if TYPE_CHECKING: from superset.models.core import Database -config = app.config -stats_logger = config["STATS_LOGGER"] -SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"] -SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60 -SQL_MAX_ROW = config["SQL_MAX_ROW"] -SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"] -log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) BYTES_IN_MB = 1024 * 1024 @@ -127,12 +119,13 @@ def handle_query_error( db.session.commit() payload.update({"status": query.status, "error": msg, "errors": errors_payload}) - if troubleshooting_link := config["TROUBLESHOOTING_LINK"]: + if troubleshooting_link := app.config["TROUBLESHOOTING_LINK"]: payload["link"] = troubleshooting_link return payload def get_query_backoff_handler(details: dict[Any, Any]) -> None: + stats_logger = app.config["STATS_LOGGER"] query_id = details["kwargs"]["query_id"] logger.error( "Query with id `%s` could not be retrieved", str(query_id), exc_info=True @@ -144,6 
+137,7 @@ def get_query_backoff_handler(details: dict[Any, Any]) -> None: def get_query_giveup_handler(_: Any) -> None: + stats_logger = app.config["STATS_LOGGER"] stats_logger.incr("error_failed_at_getting_orm_query") @@ -163,10 +157,14 @@ def get_query(query_id: int) -> Query: raise SqlLabException("Failed at getting query") from ex +# Default timeouts from config.py: +# SQLLAB_TIMEOUT = 30 seconds +# SQLLAB_ASYNC_TIME_LIMIT_SEC = 6 hours +# SQLLAB_HARD_TIMEOUT = SQLLAB_ASYNC_TIME_LIMIT_SEC + 60 @celery_app.task( name="sql_lab.get_sql_results", - time_limit=SQLLAB_HARD_TIMEOUT, - soft_time_limit=SQLLAB_TIMEOUT, + time_limit=21660, # 6 hours + 60 seconds + soft_time_limit=21600, # 6 hours ) def get_sql_results( # pylint: disable=too-many-arguments query_id: int, @@ -179,7 +177,7 @@ def get_sql_results( # pylint: disable=too-many-arguments log_params: Optional[dict[str, Any]] = None, ) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" - with current_app.test_request_context(): + with app.test_request_context(): with override_user(security_manager.find_user(username)): try: return execute_sql_statements( @@ -193,6 +191,7 @@ def get_sql_results( # pylint: disable=too-many-arguments ) except Exception as ex: # pylint: disable=broad-except logger.debug("Query %d: %s", query_id, ex) + stats_logger = app.config["STATS_LOGGER"] stats_logger.incr("error_sqllab_unhandled") query = get_query(query_id=query_id) return handle_query_error(ex, query) @@ -225,14 +224,17 @@ def apply_limit(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None: """ Apply limit to the SQL statement. 
""" + sqllab_ctas_no_limit = app.config["SQLLAB_CTAS_NO_LIMIT"] + sql_max_row = app.config["SQL_MAX_ROW"] + # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true if parsed_statement.is_mutating() or ( - query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT + query.select_as_cta_used and sqllab_ctas_no_limit ): return - if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): - query.limit = SQL_MAX_ROW + if sql_max_row and (not query.limit or query.limit > sql_max_row): + query.limit = sql_max_row if query.limit: parsed_statement.set_limit_value( @@ -252,6 +254,7 @@ def execute_query( # pylint: disable=too-many-statements, too-many-locals # no db_engine_spec = database.db_engine_spec try: + log_query = app.config["QUERY_LOGGER"] if log_query: log_query( query.database.sqlalchemy_uri, @@ -267,6 +270,7 @@ def execute_query( # pylint: disable=too-many-statements, too-many-locals # no database=database, object_ref=__name__, ): + stats_logger = app.config["STATS_LOGGER"] with stats_timing("sqllab.query.time_executing_query", stats_logger): db_engine_spec.execute_with_cursor(cursor, query.executed_sql, query) @@ -293,7 +297,7 @@ def execute_query( # pylint: disable=too-many-statements, too-many-locals # no message=__( "The query was killed after %(sqllab_timeout)s seconds. 
It might " "be too complex, or the database might be under heavy load.", - sqllab_timeout=SQLLAB_TIMEOUT, + sqllab_timeout=app.config["SQLLAB_ASYNC_TIME_LIMIT_SEC"], ), error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR, level=ErrorLevel.ERROR, @@ -338,9 +342,14 @@ def _serialize_and_expand_data( expanded_columns: list[Any] if use_msgpack: - with stats_timing( - "sqllab.query.results_backend_pa_serialization", stats_logger - ): + if has_app_context(): + stats_logger = app.config["STATS_LOGGER"] + with stats_timing( + "sqllab.query.results_backend_pa_serialization", stats_logger + ): + data = write_ipc_buffer(result_set.pa_table).to_pybytes() + else: + # No app context, skip stats timing data = write_ipc_buffer(result_set.pa_table).to_pybytes() # expand when loading data from results backend @@ -373,6 +382,7 @@ def execute_sql_statements( # noqa: C901 """Executes the sql query returns the results.""" if store_results and start_time: # only asynchronous queries + stats_logger = app.config["STATS_LOGGER"] stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time) query = get_query(query_id=query_id) @@ -391,7 +401,7 @@ def execute_sql_statements( # noqa: C901 parsed_script = SQLScript(rendered_query, engine=db_engine_spec.engine) - disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get( + disallowed_functions = app.config["DISALLOWED_SQL_FUNCTIONS"].get( db_engine_spec.engine, set(), ) @@ -538,6 +548,7 @@ def execute_sql_statements( # noqa: C901 logger.info( "Query %s: Storing results in results backend, key: %s", str(query_id), key ) + stats_logger = app.config["STATS_LOGGER"] with stats_timing("sqllab.query.results_backend_write", stats_logger): with stats_timing( "sqllab.query.results_backend_write_serialization", stats_logger @@ -547,7 +558,7 @@ def execute_sql_statements( # noqa: C901 ) # Check the size of the serialized payload - if sql_lab_payload_max_mb := config.get("SQLLAB_PAYLOAD_MAX_MB"): + if sql_lab_payload_max_mb := 
app.config.get("SQLLAB_PAYLOAD_MAX_MB"): serialized_payload_size = sys.getsizeof(serialized_payload) max_bytes = sql_lab_payload_max_mb * BYTES_IN_MB @@ -563,7 +574,7 @@ def execute_sql_statements( # noqa: C901 cache_timeout = database.cache_timeout if cache_timeout is None: - cache_timeout = config["CACHE_DEFAULT_TIMEOUT"] + cache_timeout = app.config["CACHE_DEFAULT_TIMEOUT"] compressed = zlib_compress(serialized_payload) logger.debug( @@ -596,7 +607,7 @@ def execute_sql_statements( # noqa: C901 } ) # Check the size of the serialized payload (opt-in logic for return_results) - if sql_lab_payload_max_mb := config.get("SQLLAB_PAYLOAD_MAX_MB"): + if sql_lab_payload_max_mb := app.config.get("SQLLAB_PAYLOAD_MAX_MB"): serialized_payload = _serialize_payload( payload, cast(bool, results_backend_use_msgpack) ) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index e247a0322a4..bd3fea8d380 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -22,7 +22,6 @@ import time from contextlib import closing from typing import Any, cast -from superset import app from superset.models.core import Database from superset.sql.parse import SQLScript, SQLStatement from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation @@ -30,7 +29,6 @@ from superset.utils.core import QuerySource MAX_ERROR_ROWS = 10 -config = app.config logger = logging.getLogger(__name__) diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index cdb2436e583..a8ffd92f24f 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -18,13 +18,13 @@ import logging from typing import Any, cast, Optional from urllib import parse -from flask import request, Response +from flask import current_app as app, request, Response from flask_appbuilder import permission_name from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from 
marshmallow import ValidationError -from superset import app, is_feature_enabled +from superset import is_feature_enabled from superset.commands.sql_lab.estimate import QueryEstimationCommand from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand from superset.commands.sql_lab.export import SqlResultExportCommand @@ -65,7 +65,6 @@ from superset.utils import core as utils, json from superset.views.base import CsvResponse, generate_download_headers, json_success from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics -config = app.config logger = logging.getLogger(__name__) @@ -430,7 +429,7 @@ class SqlLabRestApi(BaseSupersetApi): ) execution_context_convertor = ExecutionContextConvertor() execution_context_convertor.set_max_row_in_display( - int(config.get("DISPLAY_MAX_ROW")) + int(app.config.get("DISPLAY_MAX_ROW")) ) return ExecuteSqlCommand( execution_context, @@ -440,7 +439,7 @@ class SqlLabRestApi(BaseSupersetApi): SqlQueryRenderImpl(get_template_processor), sql_json_executor, execution_context_convertor, - config["SQLLAB_CTAS_NO_LIMIT"], + app.config["SQLLAB_CTAS_NO_LIMIT"], log_params, ) @@ -455,7 +454,7 @@ class SqlLabRestApi(BaseSupersetApi): sql_json_executor = SynchronousSqlJsonExecutor( query_dao, get_sql_results, - config.get("SQLLAB_TIMEOUT"), + app.config.get("SQLLAB_TIMEOUT"), is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"), ) return sql_json_executor diff --git a/superset/tasks/cron_util.py b/superset/tasks/cron_util.py index 329937fb828..53fba16e12f 100644 --- a/superset/tasks/cron_util.py +++ b/superset/tasks/cron_util.py @@ -20,17 +20,16 @@ from collections.abc import Iterator from datetime import datetime, timedelta from croniter import croniter +from flask import current_app from pytz import timezone as pytz_timezone, UnknownTimeZoneError -from superset import app - logger = logging.getLogger(__name__) def cron_schedule_window( triggered_at: datetime, cron: str, timezone: str ) -> 
Iterator[datetime]: - window_size = app.config["ALERT_REPORTS_CRON_WINDOW_SIZE"] + window_size = current_app.config["ALERT_REPORTS_CRON_WINDOW_SIZE"] try: tz = pytz_timezone(timezone) except UnknownTimeZoneError: diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py index 3fec34d8163..c59b4f4dd8f 100644 --- a/superset/tasks/scheduler.py +++ b/superset/tasks/scheduler.py @@ -22,8 +22,9 @@ from typing import Any from celery import Task from celery.exceptions import SoftTimeLimitExceeded +from flask import current_app -from superset import app, is_feature_enabled +from superset import is_feature_enabled from superset.commands.exceptions import CommandException from superset.commands.logs.prune import LogPruneCommand from superset.commands.report.exceptions import ReportScheduleUnexpectedError @@ -45,7 +46,7 @@ def scheduler() -> None: """ Celery beat main scheduler for reports """ - stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] + stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"] stats_logger.incr("reports.scheduler") if not is_feature_enabled("ALERT_REPORTS"): @@ -53,7 +54,7 @@ def scheduler() -> None: active_schedules = ReportScheduleDAO.find_active() triggered_at = ( datetime.fromisoformat(scheduler.request.expires) - - app.config["CELERY_BEAT_SCHEDULER_EXPIRES"] + - current_app.config["CELERY_BEAT_SCHEDULER_EXPIRES"] if scheduler.request.expires else datetime.now(tz=timezone.utc) ) @@ -65,22 +66,22 @@ def scheduler() -> None: async_options = {"eta": schedule} if ( active_schedule.working_timeout is not None - and app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] + and current_app.config["ALERT_REPORTS_WORKING_TIME_OUT_KILL"] ): async_options["time_limit"] = ( active_schedule.working_timeout - + app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"] + + current_app.config["ALERT_REPORTS_WORKING_TIME_OUT_LAG"] ) async_options["soft_time_limit"] = ( active_schedule.working_timeout - + 
app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"] + + current_app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"] ) execute.apply_async((active_schedule.id,), **async_options) @celery_app.task(name="reports.execute", bind=True) def execute(self: Task, report_schedule_id: int) -> None: - stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] + stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"] stats_logger.incr("reports.execute") task_id = None @@ -115,7 +116,7 @@ def execute(self: Task, report_schedule_id: int) -> None: @celery_app.task(name="reports.prune_log") def prune_log() -> None: - stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] + stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"] stats_logger.incr("reports.prune_log") try: @@ -130,7 +131,7 @@ def prune_log() -> None: def prune_query( self: Task, retention_period_days: int | None = None, **kwargs: Any ) -> None: - stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] + stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"] stats_logger.incr("prune_query") # TODO: Deprecated: Remove support for passing retention period via options in 6.0 @@ -153,7 +154,7 @@ def prune_query( def prune_logs( self: Task, retention_period_days: int | None = None, **kwargs: Any ) -> None: - stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] + stats_logger: BaseStatsLogger = current_app.config["STATS_LOGGER"] stats_logger.incr("prune_logs") # TODO: Deprecated: Remove support for passing retention period via options in 6.0 diff --git a/superset/thumbnails/digest.py b/superset/thumbnails/digest.py index 446d06b20d9..31a179fd93c 100644 --- a/superset/thumbnails/digest.py +++ b/superset/thumbnails/digest.py @@ -20,7 +20,7 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING -from flask import current_app +from flask import current_app as app from superset import security_manager from superset.tasks.exceptions import 
ExecutorNotFoundError @@ -91,17 +91,16 @@ def _adjust_string_with_rls( def get_dashboard_digest(dashboard: Dashboard) -> str | None: - config = current_app.config try: executor_type, executor = get_executor( - executors=config["THUMBNAIL_EXECUTORS"], + executors=app.config["THUMBNAIL_EXECUTORS"], model=dashboard, current_user=get_current_user(), ) except ExecutorNotFoundError: return None - if func := config["THUMBNAIL_DASHBOARD_DIGEST_FUNC"]: + if func := app.config["THUMBNAIL_DASHBOARD_DIGEST_FUNC"]: return func(dashboard, executor_type, executor) unique_string = ( @@ -118,17 +117,16 @@ def get_dashboard_digest(dashboard: Dashboard) -> str | None: def get_chart_digest(chart: Slice) -> str | None: - config = current_app.config try: executor_type, executor = get_executor( - executors=config["THUMBNAIL_EXECUTORS"], + executors=app.config["THUMBNAIL_EXECUTORS"], model=chart, current_user=get_current_user(), ) except ExecutorNotFoundError: return None - if func := config["THUMBNAIL_CHART_DIGEST_FUNC"]: + if func := app.config["THUMBNAIL_CHART_DIGEST_FUNC"]: return func(chart, executor_type, executor) unique_string = f"{chart.params or ''}.{executor}" diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 93072e6a534..f217b62a634 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -20,7 +20,7 @@ import inspect import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable from flask import current_app as app, request from flask_caching import Cache @@ -33,11 +33,6 @@ from superset.models.cache import CacheKey from superset.utils.hashing import md5_sha_from_dict from superset.utils.json import json_int_dttm_ser -if TYPE_CHECKING: - from superset.stats_logger import BaseStatsLogger - -config = app.config -stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) @@ -65,9 +60,10 @@ def set_and_log_cache( 
dttm = datetime.utcnow().isoformat().split(".")[0] value = {**cache_value, "dttm": dttm} cache_instance.set(cache_key, value, timeout=timeout) + stats_logger = app.config["STATS_LOGGER"] stats_logger.incr("set_cache_key") - if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]: + if datasource_uid and app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]: ck = CacheKey( cache_key=cache_key, cache_timeout=cache_timeout, @@ -149,7 +145,7 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def etag_cache( # noqa: C901 cache: Cache = cache_manager.cache, get_last_modified: Callable[..., datetime] | None = None, - max_age: int | float = app.config["CACHE_DEFAULT_TIMEOUT"], + max_age: int | float | None = None, raise_for_access: Callable[..., Any] | None = None, skip: Callable[..., bool] | None = None, ) -> Callable[..., Any]: @@ -167,6 +163,9 @@ def etag_cache( # noqa: C901 """ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: # noqa: C901 + # Compute the actual timeout to use + timeout = max_age or app.config["CACHE_DEFAULT_TIMEOUT"] + @wraps(f) def wrapper(*args: Any, **kwargs: Any) -> Response: # noqa: C901 # Check if the user can access the resource @@ -231,7 +230,7 @@ def etag_cache( # noqa: C901 response.cache_control.public = True response.last_modified = content_changed_time - expiration = max_age or ONE_YEAR # max_age=0 also means far future + expiration = timeout or ONE_YEAR # max_age=0 also means far future response.expires = response.last_modified + timedelta( seconds=expiration ) @@ -239,7 +238,7 @@ def etag_cache( # noqa: C901 # if we have a cache, store the response from the request try: - cache.set(cache_key, response, timeout=max_age) + cache.set(cache_key, response, timeout=timeout) except Exception: # pylint: disable=broad-except if app.debug: raise @@ -248,9 +247,9 @@ def etag_cache( # noqa: C901 return response.make_conditional(request) wrapper.uncached = f # type: ignore - wrapper.cache_timeout = max_age 
# type: ignore + wrapper.cache_timeout = timeout # type: ignore wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access - make_name=None, timeout=max_age + make_name=None, timeout=timeout ) return wrapper diff --git a/superset/utils/core.py b/superset/utils/core.py index 0e0d2e5bd32..81a4a7078d4 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -67,7 +67,7 @@ import pandas as pd import sqlalchemy as sa from cryptography.hazmat.backends import default_backend from cryptography.x509 import Certificate, load_pem_x509_certificate -from flask import current_app, g, request +from flask import current_app as app, g, request from flask_appbuilder import SQLA from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __ @@ -1381,7 +1381,9 @@ def create_ssl_cert_file(certificate: str) -> str: :raises CertificateException: If certificate is not valid/unparseable """ filename = f"{md5_sha_from_str(certificate)}.crt" - cert_dir = current_app.config["SSL_CERT_PATH"] + # pylint: disable=import-outside-toplevel + + cert_dir = app.config["SSL_CERT_PATH"] path = cert_dir if cert_dir else tempfile.gettempdir() path = os.path.join(path, filename) if not os.path.exists(path): @@ -1428,7 +1430,9 @@ class DatasourceName(NamedTuple): def get_stacktrace() -> str | None: - if current_app.config["SHOW_STACKTRACE"]: + # pylint: disable=import-outside-toplevel + + if app.config["SHOW_STACKTRACE"]: return traceback.format_exc() return None @@ -1867,10 +1871,12 @@ def apply_max_row_limit( >>> apply_max_row_limit(0) # Zero returns default max limit 50000 """ + # pylint: disable=import-outside-toplevel + max_limit = ( - current_app.config["TABLE_VIZ_MAX_ROW_SERVER"] + app.config["TABLE_VIZ_MAX_ROW_SERVER"] if server_pagination - else current_app.config["SQL_MAX_ROW"] + else app.config["SQL_MAX_ROW"] ) if limit != 0: return min(max_limit, limit) @@ -1894,15 +1900,17 @@ def check_is_safe_zip(zip_file: 
ZipFile) -> None: :param zip_file: :return: """ + # pylint: disable=import-outside-toplevel + uncompress_size = 0 compress_size = 0 for zip_file_element in zip_file.infolist(): - if zip_file_element.file_size > current_app.config["ZIPPED_FILE_MAX_SIZE"]: + if zip_file_element.file_size > app.config["ZIPPED_FILE_MAX_SIZE"]: raise SupersetException("Found file with size above allowed threshold") uncompress_size += zip_file_element.file_size compress_size += zip_file_element.compress_size compress_ratio = uncompress_size / compress_size - if compress_ratio > current_app.config["ZIP_FILE_MAX_COMPRESS_RATIO"]: + if compress_ratio > app.config["ZIP_FILE_MAX_COMPRESS_RATIO"]: raise SupersetException("Zip compress ratio above allowed threshold") @@ -1939,8 +1947,10 @@ def get_query_source_from_request() -> QuerySource | None: def get_user_agent(database: Database, source: QuerySource | None) -> str: + # pylint: disable=import-outside-toplevel + source = source or get_query_source_from_request() - if user_agent_func := current_app.config["USER_AGENT_FUNC"]: + if user_agent_func := app.config["USER_AGENT_FUNC"]: return user_agent_func(database, source) return DEFAULT_USER_AGENT diff --git a/superset/utils/database.py b/superset/utils/database.py index 719e7f2d772..cc16eaacb01 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -19,7 +19,7 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING -from flask import current_app +from flask import current_app as app from superset.constants import EXAMPLES_DB_UUID @@ -64,11 +64,15 @@ def get_or_create_db( def get_example_database() -> Database: - return get_or_create_db("examples", current_app.config["SQLALCHEMY_EXAMPLES_URI"]) + # pylint: disable=import-outside-toplevel + + return get_or_create_db("examples", app.config["SQLALCHEMY_EXAMPLES_URI"]) def get_main_database() -> Database: - db_uri = current_app.config["SQLALCHEMY_DATABASE_URI"] + # pylint: 
disable=import-outside-toplevel + + db_uri = app.config["SQLALCHEMY_DATABASE_URI"] return get_or_create_db("main", db_uri) diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index d5b73b0447b..cb5711452ba 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -24,7 +24,7 @@ from functools import wraps from typing import Any, Callable, TYPE_CHECKING from uuid import UUID -from flask import current_app, g, Response +from flask import current_app as app, g, Response from sqlalchemy.exc import SQLAlchemyError from superset.utils import core as utils @@ -46,19 +46,15 @@ def statsd_gauge(metric_prefix: str | None = None) -> Callable[..., Any]: metric_prefix_ = metric_prefix or f.__name__ try: result = f(*args, **kwargs) - current_app.config["STATS_LOGGER"].gauge(f"{metric_prefix_}.ok", 1) + app.config["STATS_LOGGER"].gauge(f"{metric_prefix_}.ok", 1) return result except Exception as ex: if ( hasattr(ex, "status") and ex.status < 500 # pylint: disable=no-member ): - current_app.config["STATS_LOGGER"].gauge( - f"{metric_prefix_}.warning", 1 - ) + app.config["STATS_LOGGER"].gauge(f"{metric_prefix_}.warning", 1) else: - current_app.config["STATS_LOGGER"].gauge( - f"{metric_prefix_}.error", 1 - ) + app.config["STATS_LOGGER"].gauge(f"{metric_prefix_}.error", 1) raise return wrapped diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index 53401a70602..ab05e455a38 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -68,7 +68,7 @@ class EncryptedFieldFactory: def init_app(self, app: Flask) -> None: self._config = app.config - self._concrete_type_adapter = self._config[ # type: ignore + self._concrete_type_adapter = app.config[ "SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER" ]() diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py index df3979b0a7c..9f9d2fe79b6 100644 --- a/superset/utils/machine_auth.py +++ b/superset/utils/machine_auth.py @@ -21,7 +21,7 @@ import logging from 
typing import Any, Callable, TYPE_CHECKING from urllib.parse import urlparse -from flask import current_app, Flask, request, Response, session +from flask import current_app as app, Flask, request, Response, session from flask_login import login_user from selenium.webdriver.remote.webdriver import WebDriver from werkzeug.http import parse_cookie @@ -86,7 +86,7 @@ class MachineAuthProvider: if self._auth_webdriver_func_override: return self._auth_webdriver_func_override(browser_context, user) - url = urlparse(current_app.config["WEBDRIVER_BASEURL"]) + url = urlparse(app.config["WEBDRIVER_BASEURL"]) # Setting cookies requires doing a request first page = browser_context.new_page() @@ -122,13 +122,13 @@ class MachineAuthProvider: @staticmethod def get_auth_cookies(user: User) -> dict[str, str]: # Login with the user specified to get the reports - with current_app.test_request_context("/login"): + with app.test_request_context("/login"): login_user(user) # A mock response object to get the cookie information from response = Response() # To ensure all `after_request` functions are called i.e Websockets JWT Auth - current_app.process_response(response) - current_app.session_interface.save_session(current_app, session, response) + app.process_response(response) + app.session_interface.save_session(app, session, response) cookies = {} diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index b93f89c8700..ebe1f4012eb 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -23,7 +23,7 @@ from typing import Any, Iterator, TYPE_CHECKING import backoff import jwt -from flask import current_app, url_for +from flask import current_app as app, url_for from marshmallow import EXCLUDE, fields, post_load, Schema, validate from superset import db @@ -128,8 +128,8 @@ def encode_oauth2_state(state: OAuth2State) -> str: } encoded_state = jwt.encode( payload=payload, - key=current_app.config["SECRET_KEY"], - 
algorithm=current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"], + key=app.config["SECRET_KEY"], + algorithm=app.config["DATABASE_OAUTH2_JWT_ALGORITHM"], ) # Google OAuth2 needs periods to be escaped. @@ -175,8 +175,8 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State: payload = jwt.decode( jwt=encoded_state, - key=current_app.config["SECRET_KEY"], - algorithms=[current_app.config["DATABASE_OAUTH2_JWT_ALGORITHM"]], + key=app.config["SECRET_KEY"], + algorithms=[app.config["DATABASE_OAUTH2_JWT_ALGORITHM"]], ) state = oauth2_state_schema.load(payload) diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index cf28dcf916c..23493471b84 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -23,9 +23,9 @@ from enum import Enum from io import BytesIO from typing import cast, TYPE_CHECKING, TypedDict -from flask import current_app +from flask import current_app as app -from superset import app, feature_flag_manager, thumbnail_cache +from superset import feature_flag_manager, thumbnail_cache from superset.extensions import event_logger from superset.utils.hashing import md5_sha_from_dict from superset.utils.urls import modify_url_query @@ -147,7 +147,10 @@ class ScreenshotCachePayload: class BaseScreenshot: - driver_type = current_app.config["WEBDRIVER_TYPE"] + @property + def driver_type(self) -> str: + return app.config["WEBDRIVER_TYPE"] + url: str digest: str | None screenshot: bytes | None diff --git a/superset/utils/slack.py b/superset/utils/slack.py index 34d48bef21b..e1557934a66 100644 --- a/superset/utils/slack.py +++ b/superset/utils/slack.py @@ -19,7 +19,7 @@ import logging from typing import Callable, Optional -from flask import current_app +from flask import current_app as app from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.http_retry.builtin_handlers import RateLimitErrorRetryHandler @@ -45,10 +45,10 @@ class SlackClientError(Exception): def get_slack_client() -> 
WebClient: - token: str = current_app.config["SLACK_API_TOKEN"] + token: str = app.config["SLACK_API_TOKEN"] if callable(token): token = token() - client = WebClient(token=token, proxy=current_app.config["SLACK_PROXY"]) + client = WebClient(token=token, proxy=app.config["SLACK_PROXY"]) rate_limit_handler = RateLimitErrorRetryHandler(max_retry_count=2) client.retry_handlers.append(rate_limit_handler) @@ -102,7 +102,7 @@ def get_channels_with_search( try: channels = get_channels( force=force, - cache_timeout=current_app.config["SLACK_CACHE_TIMEOUT"], + cache_timeout=app.config["SLACK_CACHE_TIMEOUT"], ) except (SlackClientError, SlackApiError) as ex: raise SupersetException(f"Failed to list channels: {ex}") from ex diff --git a/superset/utils/urls.py b/superset/utils/urls.py index 08b30dff68f..9e672eb9441 100644 --- a/superset/utils/urls.py +++ b/superset/utils/urls.py @@ -19,13 +19,13 @@ from contextlib import nullcontext from typing import Any from urllib.parse import urlparse -from flask import current_app, has_request_context, url_for +from flask import current_app as app, has_request_context, url_for def get_url_host(user_friendly: bool = False) -> str: if user_friendly: - return current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"] - return current_app.config["WEBDRIVER_BASEURL"] + return app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"] + return app.config["WEBDRIVER_BASEURL"] def headless_url(path: str, user_friendly: bool = False) -> str: @@ -36,7 +36,7 @@ def get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str: if has_request_context(): request_context = nullcontext else: - request_context = current_app.test_request_context + request_context = app.test_request_context with request_context(): return headless_url(url_for(view, **kwargs), user_friendly=user_friendly) diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 3216521bf10..6570e550f47 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py 
@@ -23,7 +23,7 @@ from enum import Enum from time import sleep from typing import TYPE_CHECKING -from flask import current_app +from flask import current_app as app from packaging import version from selenium import __version__ as selenium_version from selenium.common.exceptions import ( @@ -75,8 +75,8 @@ class WebDriverProxy(ABC): def __init__(self, driver_type: str, window: WindowSize | None = None): self._driver_type = driver_type self._window: WindowSize = window or (800, 600) - self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"] - self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"] + self._screenshot_locate_wait = app.config["SCREENSHOT_LOCATE_WAIT"] + self._screenshot_load_wait = app.config["SCREENSHOT_LOAD_WAIT"] @abstractmethod def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: @@ -140,11 +140,9 @@ class WebDriverPlaywright(WebDriverProxy): self, url: str, element_name: str, user: User ) -> bytes | None: with sync_playwright() as playwright: - browser_args = current_app.config["WEBDRIVER_OPTION_ARGS"] + browser_args = app.config["WEBDRIVER_OPTION_ARGS"] browser = playwright.chromium.launch(args=browser_args) - pixel_density = current_app.config["WEBDRIVER_WINDOW"].get( - "pixel_density", 1 - ) + pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1) context = browser.new_context( bypass_csp=True, viewport={ @@ -154,24 +152,24 @@ class WebDriverPlaywright(WebDriverProxy): device_scale_factor=pixel_density, ) context.set_default_timeout( - current_app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"] + app.config["SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT"] ) self.auth(user, context) page = context.new_page() try: page.goto( url, - wait_until=current_app.config["SCREENSHOT_PLAYWRIGHT_WAIT_EVENT"], + wait_until=app.config["SCREENSHOT_PLAYWRIGHT_WAIT_EVENT"], ) except PlaywrightTimeout: logger.exception( "Web event %s not detected. 
Page %s might not have been fully loaded", # noqa: E501 - current_app.config["SCREENSHOT_PLAYWRIGHT_WAIT_EVENT"], + app.config["SCREENSHOT_PLAYWRIGHT_WAIT_EVENT"], url, ) img: bytes | None = None - selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"] + selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"] logger.debug("Sleeping for %i seconds", selenium_headstart) page.wait_for_timeout(selenium_headstart * 1000) element: Locator @@ -212,7 +210,7 @@ class WebDriverPlaywright(WebDriverProxy): ) raise - selenium_animation_wait = current_app.config[ + selenium_animation_wait = app.config[ "SCREENSHOT_SELENIUM_ANIMATION_WAIT" ] logger.debug( @@ -224,7 +222,7 @@ class WebDriverPlaywright(WebDriverProxy): url, user.username, ) - if current_app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: + if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: unexpected_errors = WebDriverPlaywright.find_unexpected_errors(page) if unexpected_errors: logger.warning( @@ -246,7 +244,7 @@ class WebDriverPlaywright(WebDriverProxy): class WebDriverSelenium(WebDriverProxy): def create(self) -> WebDriver: - pixel_density = current_app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1) + pixel_density = app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1) if self._driver_type == "firefox": driver_class: type[WebDriver] = firefox.webdriver.WebDriver service_class: type[Service] = firefox.service.Service @@ -268,11 +266,11 @@ class WebDriverSelenium(WebDriverProxy): ) # Prepare args for the webdriver init - for arg in list(current_app.config["WEBDRIVER_OPTION_ARGS"]): + for arg in list(app.config["WEBDRIVER_OPTION_ARGS"]): options.add_argument(arg) # Add additional configured webdriver options - webdriver_conf = dict(current_app.config["WEBDRIVER_CONFIGURATION"]) + webdriver_conf = dict(app.config["WEBDRIVER_CONFIGURATION"]) # Set the binary location if provided # We need to pop it from the dict due to selenium_version < 4.10.0 @@ -344,7 +342,7 @@ class 
WebDriverSelenium(WebDriverProxy): # wait for modal to show up modal = WebDriverWait( driver, - current_app.config["SCREENSHOT_WAIT_FOR_ERROR_MODAL_VISIBLE"], + app.config["SCREENSHOT_WAIT_FOR_ERROR_MODAL_VISIBLE"], ).until( EC.visibility_of_any_elements_located( (By.CLASS_NAME, "ant-modal-content") @@ -362,7 +360,7 @@ class WebDriverSelenium(WebDriverProxy): # wait until the modal becomes invisible WebDriverWait( driver, - current_app.config["SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE"], + app.config["SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE"], ).until(EC.invisibility_of_element(modal)) # Use HTML so that error messages are shown in the same style (color) @@ -388,7 +386,7 @@ class WebDriverSelenium(WebDriverProxy): driver.set_window_size(*self._window) driver.get(url) img: bytes | None = None - selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"] + selenium_headstart = app.config["SCREENSHOT_SELENIUM_HEADSTART"] logger.debug("Sleeping for %i seconds", selenium_headstart) sleep(selenium_headstart) @@ -443,9 +441,7 @@ class WebDriverSelenium(WebDriverProxy): ) raise - selenium_animation_wait = current_app.config[ - "SCREENSHOT_SELENIUM_ANIMATION_WAIT" - ] + selenium_animation_wait = app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"] logger.debug("Wait %i seconds for chart animation", selenium_animation_wait) sleep(selenium_animation_wait) logger.debug( @@ -454,7 +450,7 @@ class WebDriverSelenium(WebDriverProxy): user.username, ) - if current_app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: + if app.config["SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"]: unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver) if unexpected_errors: logger.warning( @@ -483,5 +479,5 @@ class WebDriverSelenium(WebDriverProxy): ) raise finally: - self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"]) + self.destroy(driver, app.config["SCREENSHOT_SELENIUM_RETRIES"]) return img diff --git a/superset/views/base.py b/superset/views/base.py index 
540d7e9a2c8..2318efe1f90 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -26,6 +26,7 @@ from typing import Any, Callable from babel import Locale from flask import ( abort, + current_app as app, flash, g, get_flashed_messages, @@ -48,9 +49,7 @@ from sqlalchemy.orm import Query from wtforms.fields.core import Field, UnboundField from superset import ( - app as superset_app, appbuilder, - conf, db, get_feature_flags, is_feature_enabled, @@ -129,11 +128,10 @@ FRONTEND_CONF_KEYS = ( ) logger = logging.getLogger(__name__) -config = superset_app.config def get_error_msg() -> str: - if conf.get("SHOW_STACKTRACE"): + if app.config.get("SHOW_STACKTRACE"): error_msg = traceback.format_exc() else: error_msg = "FATAL ERROR \n" @@ -240,10 +238,10 @@ class BaseSupersetView(BaseView): def get_environment_tag() -> dict[str, Any]: # Whether flask is in debug mode (--debug) - debug = appbuilder.app.config["DEBUG"] + debug = app.config["DEBUG"] # Getting the configuration option for ENVIRONMENT_TAG_CONFIG - env_tag_config = appbuilder.app.config["ENVIRONMENT_TAG_CONFIG"] + env_tag_config = app.config["ENVIRONMENT_TAG_CONFIG"] # These are the predefined templates define in the config env_tag_templates = env_tag_config.get("values") @@ -268,32 +266,31 @@ def menu_data(user: User) -> dict[str, Any]: for lang in appbuilder.languages } - if callable(brand_text := appbuilder.app.config["LOGO_RIGHT_TEXT"]): + if callable(brand_text := app.config["LOGO_RIGHT_TEXT"]): brand_text = brand_text() return { "menu": appbuilder.menu.get_data(), "brand": { - "path": appbuilder.app.config["LOGO_TARGET_PATH"] - or url_for("Superset.welcome"), + "path": app.config["LOGO_TARGET_PATH"] or url_for("Superset.welcome"), "icon": appbuilder.app_icon, "alt": appbuilder.app_name, - "tooltip": appbuilder.app.config["LOGO_TOOLTIP"], + "tooltip": app.config["LOGO_TOOLTIP"], "text": brand_text, }, "environment_tag": get_environment_tag(), "navbar_right": { # show the watermark if the default app 
icon has been overridden "show_watermark": ("superset-logo-horiz" not in appbuilder.app_icon), - "bug_report_url": appbuilder.app.config["BUG_REPORT_URL"], - "bug_report_icon": appbuilder.app.config["BUG_REPORT_ICON"], - "bug_report_text": appbuilder.app.config["BUG_REPORT_TEXT"], - "documentation_url": appbuilder.app.config["DOCUMENTATION_URL"], - "documentation_icon": appbuilder.app.config["DOCUMENTATION_ICON"], - "documentation_text": appbuilder.app.config["DOCUMENTATION_TEXT"], - "version_string": appbuilder.app.config["VERSION_STRING"], - "version_sha": appbuilder.app.config["VERSION_SHA"], - "build_number": appbuilder.app.config["BUILD_NUMBER"], + "bug_report_url": app.config["BUG_REPORT_URL"], + "bug_report_icon": app.config["BUG_REPORT_ICON"], + "bug_report_text": app.config["BUG_REPORT_TEXT"], + "documentation_url": app.config["DOCUMENTATION_URL"], + "documentation_icon": app.config["DOCUMENTATION_ICON"], + "documentation_text": app.config["DOCUMENTATION_TEXT"], + "version_string": app.config["VERSION_STRING"], + "version_sha": app.config["VERSION_SHA"], + "build_number": app.config["BUILD_NUMBER"], "languages": languages, "show_language_picker": len(languages) > 1, "user_is_anonymous": user.is_anonymous, @@ -312,9 +309,9 @@ def get_theme_bootstrap_data() -> dict[str, Any]: Returns the theme data to be sent to the client. 
""" # Get theme configs - default_theme_config = get_config_value(conf, "THEME_DEFAULT") - dark_theme_config = get_config_value(conf, "THEME_DARK") - theme_settings = get_config_value(conf, "THEME_SETTINGS") + default_theme_config = get_config_value("THEME_DEFAULT") + dark_theme_config = get_config_value("THEME_DARK") + theme_settings = get_config_value("THEME_SETTINGS") # Validate theme configurations default_theme = default_theme_config @@ -360,11 +357,15 @@ def cached_common_bootstrap_data( # pylint: disable=unused-argument # should not expose API TOKEN to frontend frontend_config = { - k: (list(conf.get(k)) if isinstance(conf.get(k), set) else conf.get(k)) + k: ( + list(app.config.get(k)) + if isinstance(app.config.get(k), set) + else app.config.get(k) + ) for k in FRONTEND_CONF_KEYS } - if conf.get("SLACK_API_TOKEN"): + if app.config.get("SLACK_API_TOKEN"): frontend_config["ALERT_REPORTS_NOTIFICATION_METHODS"] = [ ReportRecipientType.EMAIL, ReportRecipientType.SLACK, @@ -383,19 +384,17 @@ def cached_common_bootstrap_data( # pylint: disable=unused-argument ) language = locale.language if locale else "en" - auth_type = appbuilder.app.config["AUTH_TYPE"] - auth_user_registration = appbuilder.app.config["AUTH_USER_REGISTRATION"] + auth_type = app.config["AUTH_TYPE"] + auth_user_registration = app.config["AUTH_USER_REGISTRATION"] frontend_config["AUTH_USER_REGISTRATION"] = auth_user_registration should_show_recaptcha = auth_user_registration and (auth_type != AUTH_OAUTH) if auth_user_registration: - frontend_config["AUTH_USER_REGISTRATION_ROLE"] = appbuilder.app.config[ + frontend_config["AUTH_USER_REGISTRATION_ROLE"] = app.config[ "AUTH_USER_REGISTRATION_ROLE" ] if should_show_recaptcha: - frontend_config["RECAPTCHA_PUBLIC_KEY"] = appbuilder.app.config[ - "RECAPTCHA_PUBLIC_KEY" - ] + frontend_config["RECAPTCHA_PUBLIC_KEY"] = app.config["RECAPTCHA_PUBLIC_KEY"] frontend_config["AUTH_TYPE"] = auth_type if auth_type == AUTH_OAUTH: @@ -416,21 +415,23 @@ def 
cached_common_bootstrap_data( # pylint: disable=unused-argument frontend_config["AUTH_PROVIDERS"] = oid_providers bootstrap_data = { - "application_root": conf["APPLICATION_ROOT"], - "static_assets_prefix": conf["STATIC_ASSETS_PREFIX"], + "application_root": app.config["APPLICATION_ROOT"], + "static_assets_prefix": app.config["STATIC_ASSETS_PREFIX"], "conf": frontend_config, "locale": language, - "d3_format": conf.get("D3_FORMAT"), - "d3_time_format": conf.get("D3_TIME_FORMAT"), - "currencies": conf.get("CURRENCIES"), - "deckgl_tiles": conf.get("DECKGL_BASE_MAP"), + "d3_format": app.config.get("D3_FORMAT"), + "d3_time_format": app.config.get("D3_TIME_FORMAT"), + "currencies": app.config.get("CURRENCIES"), + "deckgl_tiles": app.config.get("DECKGL_BASE_MAP"), "feature_flags": get_feature_flags(), - "extra_sequential_color_schemes": conf["EXTRA_SEQUENTIAL_COLOR_SCHEMES"], - "extra_categorical_color_schemes": conf["EXTRA_CATEGORICAL_COLOR_SCHEMES"], + "extra_sequential_color_schemes": app.config["EXTRA_SEQUENTIAL_COLOR_SCHEMES"], + "extra_categorical_color_schemes": app.config[ + "EXTRA_CATEGORICAL_COLOR_SCHEMES" + ], "menu_data": menu_data(g.user), } - bootstrap_data.update(conf["COMMON_BOOTSTRAP_OVERRIDES_FUNC"](bootstrap_data)) + bootstrap_data.update(app.config["COMMON_BOOTSTRAP_OVERRIDES_FUNC"](bootstrap_data)) bootstrap_data.update(get_theme_bootstrap_data()) return bootstrap_data @@ -443,17 +444,6 @@ def common_bootstrap_payload() -> dict[str, Any]: } -@superset_app.context_processor -def get_common_bootstrap_data() -> dict[str, Any]: - def serialize_bootstrap_data() -> str: - return json.dumps( - {"common": common_bootstrap_payload()}, - default=json.pessimistic_json_iso_dttm_ser, - ) - - return {"bootstrap_data": serialize_bootstrap_data} - - class SupersetListWidget(ListWidget): # pylint: disable=too-few-public-methods template = "superset/fab_overrides/list.html" @@ -549,7 +539,7 @@ class CsvResponse(Response): Override Response to take into account csv 
encoding from config.py """ - charset = conf["CSV_EXPORT"].get("encoding", "utf-8") + charset = app.config["CSV_EXPORT"].get("encoding", "utf-8") default_mimetype = "text/csv" @@ -582,18 +572,3 @@ def bind_field( FlaskForm.Meta.bind_field = bind_field - - -@superset_app.after_request -def apply_http_headers(response: Response) -> Response: - """Applies the configuration's http headers to all responses""" - - # HTTP_HEADERS is deprecated, this provides backwards compatibility - response.headers.extend( - {**config["OVERRIDE_HTTP_HEADERS"], **config["HTTP_HEADERS"]} - ) - - for k, v in config["DEFAULT_HTTP_HEADERS"].items(): - if k not in response.headers: - response.headers[k] = v - return response diff --git a/superset/views/core.py b/superset/views/core.py index c77f8dc4fe9..384e88d159e 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -27,6 +27,7 @@ from urllib import parse from flask import ( abort, + current_app as app, flash, g, redirect, @@ -46,7 +47,6 @@ from sqlalchemy.exc import SQLAlchemyError from werkzeug.utils import safe_join from superset import ( - app, appbuilder, db, event_logger, @@ -114,9 +114,6 @@ from superset.views.utils import ( ) from superset.viz import BaseViz -config = app.config -SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"] -stats_logger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) DATASOURCE_MISSING_ERR = __("The data source seems to have been deleted") diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 57d10d8a36c..636d2730504 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -16,10 +16,10 @@ # under the License. 
from typing import TYPE_CHECKING +from flask import current_app as app from flask_appbuilder import expose from flask_appbuilder.security.decorators import has_access -from superset import app from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP from superset.superset_typing import FlaskResponse from superset.views.base import BaseSupersetView @@ -27,9 +27,6 @@ from superset.views.base import BaseSupersetView if TYPE_CHECKING: from werkzeug.datastructures import FileStorage -config = app.config -stats_logger = config["STATS_LOGGER"] - def upload_stream_write(form_file_field: "FileStorage", path: str) -> None: chunk_size = app.config["UPLOAD_CHUNK_SIZE"] diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 8ae28f9c7a2..d66f3347124 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -16,9 +16,9 @@ # under the License. from typing import Any, Optional, TypedDict +from flask import current_app as app from marshmallow import fields, post_load, pre_load, Schema, validate -from superset import app from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema from superset.utils.core import DatasourceType @@ -100,6 +100,20 @@ class SamplesRequestSchema(Schema): force = fields.Boolean(load_default=False) page = fields.Integer(load_default=1) per_page = fields.Integer( - validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)), - load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000), + validate=validate.Range(min=1, max=1000), + load_default=None, ) + + @pre_load + def set_default_per_page( + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: + # Create a mutable copy if data is immutable (e.g., request.args) + if hasattr(data, "to_dict"): + data = data.to_dict() + elif not isinstance(data, dict): + data = dict(data) + + if "per_page" not in data: + data["per_page"] = app.config.get("SAMPLES_ROW_LIMIT", 1000) + return data 
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 6bad2370c87..75c9eb7b0d3 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -16,7 +16,8 @@ # under the License. from typing import Any, Optional -from superset import app +from flask import current_app as app + from superset.commands.dataset.exceptions import DatasetSamplesFailedError from superset.common.chart_data import ChartDataResultType from superset.common.query_context_factory import QueryContextFactory diff --git a/superset/views/error_handling.py b/superset/views/error_handling.py index 946142b0fca..56cd356a125 100644 --- a/superset/views/error_handling.py +++ b/superset/views/error_handling.py @@ -158,6 +158,7 @@ def set_app_error_handlers(app: Flask) -> None: # noqa: C901 @app.errorhandler(HTTPException) def show_http_exception(ex: HTTPException) -> FlaskResponse: logger.warning("HTTPException", exc_info=True) + if ( "text/html" in request.accept_mimetypes and not app.config["DEBUG"] @@ -185,6 +186,7 @@ def set_app_error_handlers(app: Flask) -> None: # noqa: C901 or SupersetErrorsException, with a specific status code and error type """ logger.warning("CommandException", exc_info=True) + if "text/html" in request.accept_mimetypes and not app.config["DEBUG"]: path = files("superset") / "static/assets/500.html" return send_file(path, max_age=0), 500 @@ -208,6 +210,7 @@ def set_app_error_handlers(app: Flask) -> None: # noqa: C901 """Catch-all, to ensure all errors from the backend conform to SIP-40""" logger.warning("Exception", exc_info=True) logger.exception(ex) + if "text/html" in request.accept_mimetypes and not app.config["DEBUG"]: path = files("superset") / "static/assets/500.html" return send_file(path, max_age=0), 500 diff --git a/superset/views/filters.py b/superset/views/filters.py index ad9dc8c38aa..c077e1910f5 100644 --- a/superset/views/filters.py +++ b/superset/views/filters.py @@ -17,7 +17,7 @@ import logging 
from typing import Any, cast, Optional -from flask import current_app +from flask import current_app as app from flask_appbuilder.models.filters import BaseFilter from flask_babel import lazy_gettext from sqlalchemy import and_, or_ @@ -70,15 +70,15 @@ class BaseFilterRelatedUsers(BaseFilter): # pylint: disable=too-few-public-meth arg_name = "username" def apply(self, query: Query, value: Optional[Any]) -> Query: - if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get( + if extra_filters := app.config["EXTRA_RELATED_QUERY_FILTERS"].get( "user", ): query = extra_filters(query) exclude_users = ( security_manager.get_exclude_users_from_lists() - if current_app.config["EXCLUDE_USERS_FROM_LISTS"] is None - else current_app.config["EXCLUDE_USERS_FROM_LISTS"] + if app.config["EXCLUDE_USERS_FROM_LISTS"] is None + else app.config["EXCLUDE_USERS_FROM_LISTS"] ) if exclude_users: user_model = security_manager.user_model @@ -96,7 +96,7 @@ class BaseFilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-meth arg_name = "role" def apply(self, query: Query, value: Optional[Any]) -> Query: - if extra_filters := current_app.config["EXTRA_RELATED_QUERY_FILTERS"].get( + if extra_filters := app.config["EXTRA_RELATED_QUERY_FILTERS"].get( "role", ): return extra_filters(query) diff --git a/superset/views/health.py b/superset/views/health.py index 8b082ff88ff..46179fda9c5 100644 --- a/superset/views/health.py +++ b/superset/views/health.py @@ -14,15 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset import app, talisman +from flask import Blueprint, current_app as app + +from superset import talisman from superset.stats_logger import BaseStatsLogger from superset.superset_typing import FlaskResponse +health_blueprint = Blueprint("health", __name__) + +@health_blueprint.route("/health") +@health_blueprint.route("/healthcheck") +@health_blueprint.route("/ping") @talisman(force_https=False) -@app.route("/health") -@app.route("/healthcheck") -@app.route("/ping") def health() -> FlaskResponse: stats_logger: BaseStatsLogger = app.config["STATS_LOGGER"] stats_logger.incr("health") diff --git a/superset/views/users/api.py b/superset/views/users/api.py index 4e24c62b07d..83d95b1bf3f 100644 --- a/superset/views/users/api.py +++ b/superset/views/users/api.py @@ -17,7 +17,7 @@ from datetime import datetime from typing import Any, Dict -from flask import g, redirect, request, Response +from flask import current_app as app, g, redirect, request, Response from flask_appbuilder.api import expose, safe from flask_appbuilder.security.sqla.models import User from flask_jwt_extended.exceptions import NoAuthorizationError @@ -25,7 +25,7 @@ from marshmallow import ValidationError from sqlalchemy.orm.exc import NoResultFound from werkzeug.security import generate_password_hash -from superset import app, is_feature_enabled +from superset import is_feature_enabled from superset.daos.user import UserDAO from superset.extensions import db, event_logger from superset.utils.slack import get_user_avatar, SlackClientError @@ -51,12 +51,8 @@ class CurrentUserRestApi(BaseSupersetApi): if "password" in data and data["password"]: item.password = generate_password_hash( password=data["password"], - method=self.appbuilder.get_app.config.get( - "FAB_PASSWORD_HASH_METHOD", "scrypt" - ), - salt_length=self.appbuilder.get_app.config.get( - "FAB_PASSWORD_HASH_SALT_LENGTH", 16 - ), + method=app.config.get("FAB_PASSWORD_HASH_METHOD", "scrypt"), + 
salt_length=app.config.get("FAB_PASSWORD_HASH_SALT_LENGTH", 16), ) @expose("/", methods=("GET",)) @@ -171,7 +167,7 @@ class CurrentUserRestApi(BaseSupersetApi): setattr(g.user, key, value) self.pre_update(g.user, item) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction return self.response(200, result=user_response_schema.dump(g.user)) except ValidationError as error: return self.response_400(message=error.messages) @@ -221,6 +217,7 @@ class UserRestApi(BaseSupersetApi): # fetch from the one-to-one relationship if len(user.extra_attributes) > 0: avatar_url = user.extra_attributes[0].avatar_url + slack_token = app.config.get("SLACK_API_TOKEN") if ( not avatar_url diff --git a/superset/views/utils.py b/superset/views/utils.py index e4e1877a4c2..a431c307fb1 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -22,14 +22,14 @@ from typing import Any, Callable, DefaultDict, Optional, Union import msgpack import pyarrow as pa -from flask import flash, g, has_request_context, redirect, request +from flask import current_app as app, flash, g, has_request_context, redirect, request from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla.models import User from flask_babel import _ from sqlalchemy.exc import NoResultFound from werkzeug.wrappers.response import Response -from superset import app, dataframe, db, result_set, viz +from superset import dataframe, db, result_set, viz from superset.common.db_query_status import QueryStatus from superset.daos.datasource import DatasourceDAO from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -101,8 +101,8 @@ def bootstrap_user_data(user: User, include_perms: bool = False) -> dict[str, An return payload -def get_config_value(conf: Any, key: str) -> Any: - value = conf[key] +def get_config_value(key: str) -> Any: + value = app.config[key] return value() if callable(value) else value diff --git a/superset/viz.py 
b/superset/viz.py index a180183d3a4..25915a7e649 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -38,12 +38,11 @@ import pandas as pd import polyline from dateutil import relativedelta as rdelta from deprecation import deprecated -from flask import request +from flask import current_app, request from flask_babel import lazy_gettext as _ from geopy.point import Point from pandas.tseries.frequencies import to_offset -from superset import app from superset.common.db_query_status import QueryStatus from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -83,10 +82,6 @@ from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource -config = app.config -stats_logger = config["STATS_LOGGER"] -relative_start = config["DEFAULT_RELATIVE_START_TIME"] -relative_end = config["DEFAULT_RELATIVE_END_TIME"] logger = logging.getLogger(__name__) METRIC_KEYS = [ @@ -247,7 +242,7 @@ class BaseViz: # pylint: disable=too-many-public-methods "groupby": [], "metrics": [], "orderby": [], - "row_limit": config["SAMPLES_ROW_LIMIT"], + "row_limit": current_app.config["SAMPLES_ROW_LIMIT"], "columns": [o.column_name for o in self.datasource.columns], "from_dttm": None, "to_dttm": None, @@ -360,7 +355,9 @@ class BaseViz: # pylint: disable=too-many-public-methods timeseries_limit_metric = self.form_data.get("timeseries_limit_metric") # apply row limit to query - row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"]) + row_limit = int( + self.form_data.get("row_limit") or current_app.config["ROW_LIMIT"] + ) row_limit = apply_max_row_limit(row_limit) # default order direction @@ -368,8 +365,8 @@ class BaseViz: # pylint: disable=too-many-public-methods try: since, until = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=current_app.config["DEFAULT_RELATIVE_START_TIME"], + 
relative_end=current_app.config["DEFAULT_RELATIVE_END_TIME"], time_range=self.form_data.get("time_range"), since=self.form_data.get("since"), until=self.form_data.get("until"), @@ -433,9 +430,12 @@ class BaseViz: # pylint: disable=too-many-public-methods and self.datasource.database.cache_timeout ) is not None: return self.datasource.database.cache_timeout - if config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None: - return config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] - return config["CACHE_DEFAULT_TIMEOUT"] + if ( + current_app.config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") + is not None + ): + return current_app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + return current_app.config["CACHE_DEFAULT_TIMEOUT"] @deprecated(deprecated_in="3.0") def get_json(self) -> str: @@ -531,7 +531,7 @@ class BaseViz: # pylint: disable=too-many-public-methods if cache_key and cache_manager.data_cache and not force: cache_value = cache_manager.data_cache.get(cache_key) if cache_value: - stats_logger.incr("loading_from_cache") + current_app.config["STATS_LOGGER"].incr("loading_from_cache") try: df = cache_value["df"] self.query = cache_value["query"] @@ -543,7 +543,7 @@ class BaseViz: # pylint: disable=too-many-public-methods ) self.status = QueryStatus.SUCCESS is_loaded = True - stats_logger.incr("loaded_from_cache") + current_app.config["STATS_LOGGER"].incr("loaded_from_cache") except Exception as ex: # pylint: disable=broad-except logger.exception(ex) logger.error( @@ -581,9 +581,11 @@ class BaseViz: # pylint: disable=too-many-public-methods ) df = self.get_df(query_obj) if self.status != QueryStatus.FAILED: - stats_logger.incr("loaded_from_source") + current_app.config["STATS_LOGGER"].incr("loaded_from_source") if not self.force: - stats_logger.incr("loaded_from_source_without_force") + current_app.config["STATS_LOGGER"].incr( + "loaded_from_source_without_force" + ) is_loaded = True except QueryObjectValidationError as ex: error = 
dataclasses.asdict( @@ -678,7 +680,9 @@ class BaseViz: # pylint: disable=too-many-public-methods def get_csv(self) -> str | None: df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) - return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) + return csv.df_to_escaped_csv( + df, index=include_index, **current_app.config["CSV_EXPORT"] + ) @deprecated(deprecated_in="3.0") def get_data(self, df: pd.DataFrame) -> VizData: @@ -773,8 +777,8 @@ class CalHeatmapViz(BaseViz): try: start, end = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=current_app.config["DEFAULT_RELATIVE_START_TIME"], + relative_end=current_app.config["DEFAULT_RELATIVE_END_TIME"], time_range=form_data.get("time_range"), since=form_data.get("since"), until=form_data.get("until"), @@ -1499,7 +1503,7 @@ class MapboxViz(BaseViz): return { "geoJSON": geo_json, "hasCustomMetric": has_custom_metric, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": current_app.config["MAPBOX_API_KEY"], "mapStyle": self.form_data.get("mapbox_style"), "aggregatorName": self.form_data.get("pandas_aggfunc"), "clusteringRadius": self.form_data.get("clustering_radius"), @@ -1636,7 +1640,7 @@ class DeckGLMultiLayer(BaseViz): return { "features": features, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": current_app.config["MAPBOX_API_KEY"], "slices": [slc.data for slc in slices if slc.data is not None], } @@ -1880,7 +1884,7 @@ class BaseDeckGLViz(BaseViz): return { "features": features, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": current_app.config["MAPBOX_API_KEY"], "metricLabels": self.metric_labels, } @@ -2205,7 +2209,7 @@ class DeckArc(BaseDeckGLViz): return { "features": super().get_data(df)["features"], - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": current_app.config["MAPBOX_API_KEY"], } @@ -2549,5 +2553,5 @@ def get_subclasses(cls: 
type[BaseViz]) -> set[type[BaseViz]]: viz_types = { o.viz_type: o for o in get_subclasses(BaseViz) - if o.viz_type not in config["VIZ_TYPE_DENYLIST"] + if o.viz_type not in current_app.config["VIZ_TYPE_DENYLIST"] } diff --git a/tests/conftest.py b/tests/conftest.py index 09a16f1871f..8a746b322ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,7 @@ import functools from typing import Any, Callable, TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock +from flask import current_app from pytest import fixture # noqa: PT013 from tests.example_data.data_loading.pandas.pandas_data_loader import PandasDataLoader @@ -127,8 +128,6 @@ def with_config(override_config: dict[str, Any]): config_backup = {} def wrapper(*args, **kwargs): - from flask import current_app - for key, value in override_config.items(): config_backup[key] = current_app.config[key] current_app.config[key] = value diff --git a/tests/integration_tests/advanced_data_type/api_tests.py b/tests/integration_tests/advanced_data_type/api_tests.py index 1a1f74455e0..403278db6d0 100644 --- a/tests/integration_tests/advanced_data_type/api_tests.py +++ b/tests/integration_tests/advanced_data_type/api_tests.py @@ -22,7 +22,7 @@ import prison from superset.utils.core import get_example_default_schema # noqa: F401 from tests.integration_tests.utils.get_dashboards import get_dashboards_ids # noqa: F401 -from unittest import mock +from tests.conftest import with_config from sqlalchemy import Column from typing import Any from superset.advanced_data_type.types import ( @@ -69,10 +69,7 @@ CHART_DATA_URI = "api/v1/chart/advanced_data_type" CHARTS_FIXTURE_COUNT = 10 -@mock.patch( - "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", - {"type": 1}, -) +@with_config({"ADVANCED_DATA_TYPES": {"type": 1}}) def test_types_type_request(test_client, login_as_admin): """ Advanced Data Type API: Test to see if the API call returns all the valid advanced data types @@ -104,10 +101,7 @@ def 
test_types_convert_bad_request_no_type(test_client, login_as_admin): assert response_value.status_code == 400 -@mock.patch( - "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", - {"type": 1}, -) +@with_config({"ADVANCED_DATA_TYPES": {"type": 1}}) def test_types_convert_bad_request_type_not_found(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when passed in type is @@ -119,10 +113,7 @@ def test_types_convert_bad_request_type_not_found(test_client, login_as_admin): assert response_value.status_code == 400 -@mock.patch( - "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", - {"type": test_type}, -) +@with_config({"ADVANCED_DATA_TYPES": {"type": test_type}}) def test_types_convert_request(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when a valid type diff --git a/tests/integration_tests/available_domains/api_tests.py b/tests/integration_tests/available_domains/api_tests.py index 8d7ea9ea92c..965d0e094c0 100644 --- a/tests/integration_tests/available_domains/api_tests.py +++ b/tests/integration_tests/available_domains/api_tests.py @@ -15,15 +15,13 @@ # specific language governing permissions and limitations # under the License. 
from superset.utils import json -from tests.integration_tests.test_app import app +from tests.conftest import with_config +@with_config({"SUPERSET_WEBSERVER_DOMAINS": ["a", "b"]}) def test_get_available_domains(test_client, login_as_admin): - cached = app.config["SUPERSET_WEBSERVER_DOMAINS"] - app.config["SUPERSET_WEBSERVER_DOMAINS"] = ["a", "b"] resp = test_client.get("api/v1/available_domains/") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) result = data.get("result") assert result == {"domains": ["a", "b"]} - app.config["SUPERSET_WEBSERVER_DOMAINS"] = cached diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index 37cb6781997..319a785b3a7 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -231,7 +231,7 @@ class ApiOwnersTestCaseMixin: return query.filter_by(username="alpha") with patch.dict( - "superset.views.filters.current_app.config", + "flask.current_app.config", {"EXTRA_RELATED_QUERY_FILTERS": {"user": _base_filter}}, ): uri = f"api/v1/{self.resource_name}/related/owners" diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py index 6857a7ccb6f..c8b07f3fe48 100644 --- a/tests/integration_tests/cache_tests.py +++ b/tests/integration_tests/cache_tests.py @@ -17,8 +17,8 @@ """Unit tests for Superset with caching""" import pytest +from flask import current_app as app -from superset import app, db # noqa: F401 from superset.common.db_query_status import QueryStatus from superset.extensions import cache_manager from superset.utils import json diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index cd95f743a68..ac184245f3b 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -48,6 +48,7 @@ from superset.utils.core import ( ) from superset.utils.database import 
get_example_database, get_main_database from tests.common.query_context_generator import ANNOTATION_LAYERS +from tests.conftest import with_config from tests.integration_tests.annotation_layers.fixtures import ( create_annotation_layers, # noqa: F401 ) @@ -225,10 +226,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): assert rv.json["result"][0]["rowcount"] == expected_row_count @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_context_factory.config", - {**app.config, "ROW_LIMIT": 7}, - ) + @with_config({"ROW_LIMIT": 7}) def test_without_row_limit__row_count_as_default_row_limit(self): # arrange expected_row_count = 7 @@ -239,10 +237,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): self.assert_row_count(rv, expected_row_count) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_context_factory.config", - {**app.config, "SAMPLES_ROW_LIMIT": 5}, - ) + @with_config({"SAMPLES_ROW_LIMIT": 5}) def test_as_samples_without_row_limit__row_count_as_default_samples_row_limit(self): # arrange expected_row_count = 5 @@ -259,7 +254,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.utils.core.current_app.config", + "flask.current_app.config", {**app.config, "SQL_MAX_ROW": 10}, ) def test_with_row_limit_bigger_then_sql_max_row__rowcount_as_sql_max_row(self): @@ -275,7 +270,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.utils.core.current_app.config", + "flask.current_app.config", {**app.config, "SQL_MAX_ROW": 5}, ) def test_as_samples_with_row_limit_bigger_then_sql_max_row_rowcount_as_sql_max_row( @@ -291,10 +286,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): assert "GROUP BY" not in rv.json["result"][0]["query"] 
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_actions.config", - {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15}, - ) + @with_config({"SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15}) def test_with_row_limit_as_samples__rowcount_as_row_limit(self): expected_row_count = 10 self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES @@ -1355,7 +1347,7 @@ def physical_query_context(physical_dataset) -> dict[str, Any]: @mock.patch( - "superset.common.query_context_processor.config", + "flask.current_app.config", { **app.config, "CACHE_DEFAULT_TIMEOUT": 1234, @@ -1410,7 +1402,7 @@ def test_force_cache_timeout(test_client, login_as_admin, physical_query_context @mock.patch( - "superset.common.query_context_processor.config", + "flask.current_app.config", { **app.config, "CACHE_DEFAULT_TIMEOUT": 100000, @@ -1455,7 +1447,7 @@ def test_chart_cache_timeout( @mock.patch( - "superset.common.query_context_processor.config", + "flask.current_app.config", { **app.config, "DATA_CACHE_CONFIG": { @@ -1482,7 +1474,7 @@ def test_chart_cache_timeout_not_present( @mock.patch( - "superset.common.query_context_processor.config", + "flask.current_app.config", { **app.config, "DATA_CACHE_CONFIG": { diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py index 9064674e4b6..d72e8c44f18 100644 --- a/tests/integration_tests/charts/schema_tests.py +++ b/tests/integration_tests/charts/schema_tests.py @@ -17,12 +17,10 @@ # isort:skip_file """Unit tests for Superset""" -from unittest import mock - import pytest from marshmallow import ValidationError -from tests.integration_tests.test_app import app +from tests.conftest import with_config from superset.charts.schemas import ChartDataQueryContextSchema from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( @@ -33,10 +31,7 @@ from 
tests.integration_tests.fixtures.query_context import get_query_context class TestSchema(SupersetTestCase): - @mock.patch( - "superset.common.query_context_factory.config", - {**app.config, "ROW_LIMIT": 5000}, - ) + @with_config({"ROW_LIMIT": 5000}) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_context_limit_and_offset(self): payload = get_query_context("birth_names") diff --git a/tests/integration_tests/cli_tests.py b/tests/integration_tests/cli_tests.py index a00cf2981d2..328e41e147f 100644 --- a/tests/integration_tests/cli_tests.py +++ b/tests/integration_tests/cli_tests.py @@ -23,11 +23,12 @@ from zipfile import is_zipfile, ZipFile import pytest import yaml # noqa: F401 +from flask import current_app from freezegun import freeze_time import superset.cli.importexport import superset.cli.thumbnails -from superset import app, db +from superset import db from superset.models.dashboard import Dashboard from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 @@ -60,7 +61,7 @@ def test_export_dashboards_versioned_export(app_context, fs): # feature flags importlib.reload(superset.cli.importexport) - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() with freeze_time("2021-01-01T00:00:00Z"): response = runner.invoke(superset.cli.importexport.export_dashboards, ()) @@ -89,7 +90,7 @@ def test_failing_export_dashboards_versioned_export( # feature flags importlib.reload(superset.cli.importexport) - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() with freeze_time("2021-01-01T00:00:00Z"): response = runner.invoke(superset.cli.importexport.export_dashboards, ()) @@ -108,7 +109,7 @@ def test_export_datasources_versioned_export(app_context, fs): # feature flags importlib.reload(superset.cli.importexport) - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() with freeze_time("2021-01-01T00:00:00Z"): response = 
runner.invoke(superset.cli.importexport.export_datasources, ()) @@ -135,7 +136,7 @@ def test_failing_export_datasources_versioned_export( # feature flags importlib.reload(superset.cli.importexport) - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() with freeze_time("2021-01-01T00:00:00Z"): response = runner.invoke(superset.cli.importexport.export_datasources, ()) @@ -158,7 +159,7 @@ def test_import_dashboards_versioned_export(import_dashboards_command, app_conte with open("dashboards.json", "w") as fp: fp.write('{"hello": "world"}') - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_dashboards, ("-p", "dashboards.json", "-u", "admin"), @@ -173,7 +174,7 @@ def test_import_dashboards_versioned_export(import_dashboards_command, app_conte with bundle.open("dashboards/dashboard.yaml", "w") as fp: fp.write(b"hello: world") - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_dashboards, ("-p", "dashboards.zip", "-u", "admin"), @@ -205,7 +206,7 @@ def test_failing_import_dashboards_versioned_export( with open("dashboards.json", "w") as fp: fp.write('{"hello": "world"}') - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_dashboards, ("-p", "dashboards.json", "-u", "admin"), @@ -218,7 +219,7 @@ def test_failing_import_dashboards_versioned_export( with bundle.open("dashboards/dashboard.yaml", "w") as fp: fp.write(b"hello: world") - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_dashboards, ("-p", "dashboards.zip", "-u", "admin"), @@ -243,7 +244,7 @@ def test_import_datasets_versioned_export(import_datasets_command, app_context, with open("datasets.yaml", "w") as fp: fp.write("hello: world") - runner = app.test_cli_runner() + 
runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_datasources, ("-p", "datasets.yaml") ) @@ -257,7 +258,7 @@ def test_import_datasets_versioned_export(import_datasets_command, app_context, with bundle.open("datasets/dataset.yaml", "w") as fp: fp.write(b"hello: world") - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_datasources, ("-p", "datasets.zip") ) @@ -288,7 +289,7 @@ def test_failing_import_datasets_versioned_export( with open("datasets.yaml", "w") as fp: fp.write("hello: world") - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_datasources, ("-p", "datasets.yaml") ) @@ -300,7 +301,7 @@ def test_failing_import_datasets_versioned_export( with bundle.open("datasets/dataset.yaml", "w") as fp: fp.write(b"hello: world") - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() response = runner.invoke( superset.cli.importexport.import_datasources, ("-p", "datasets.zip") ) @@ -312,7 +313,7 @@ def test_failing_import_datasets_versioned_export( @mock.patch("superset.tasks.thumbnails.cache_dashboard_thumbnail") def test_compute_thumbnails(thumbnail_mock, app_context, fs): thumbnail_mock.return_value = None - runner = app.test_cli_runner() + runner = current_app.test_cli_runner() dashboard = db.session.query(Dashboard).filter_by(slug="births").first() response = runner.invoke( superset.cli.thumbnails.compute_thumbnails, diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index de276c7e9e8..ce3670f18e1 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -85,11 +85,9 @@ class TestCore(SupersetTestCase): self.table_ids = { tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all()) } - self.original_unsafe_db_setting = 
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] def tearDown(self): db.session.query(Query).delete() - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting super().tearDown() def insert_dashboard_created_by(self, username: str) -> Dashboard: @@ -299,13 +297,13 @@ class TestCore(SupersetTestCase): def custom_password_store(uri): return "password_store_test" - models.custom_password_store = custom_password_store - conn = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) - if conn_pre.password: - assert conn.password == "password_store_test" # noqa: S105 - assert conn.password != conn_pre.password - # Disable for password store for later tests - models.custom_password_store = None + with mock.patch.dict( + app.config, {"SQLALCHEMY_CUSTOM_PASSWORD_STORE": custom_password_store} + ): + conn = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) + if conn_pre.password: + assert conn.password == "password_store_test" # noqa: S105 + assert conn.password != conn_pre.password @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_error(self) -> None: @@ -861,7 +859,7 @@ class TestCore(SupersetTestCase): class TestLocalePatch(SupersetTestCase): MOCK_LANGUAGES = ( - "superset.views.filters.current_app.config", + "flask.current_app.config", { "LANGUAGES": { "es": {"flag": "es", "name": "Español"}, diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 2f662c32b93..45aad249ecb 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2584,7 +2584,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas return query.filter_by(name="Alpha") with patch.dict( - "superset.views.filters.current_app.config", + "flask.current_app.config", {"EXTRA_RELATED_QUERY_FILTERS": {"role": _base_filter}}, ): uri = "api/v1/dashboard/related/roles" # noqa: F541 diff --git 
a/tests/integration_tests/dashboards/base_case.py b/tests/integration_tests/dashboards/base_case.py index 86462001404..17ffdd9b0c1 100644 --- a/tests/integration_tests/dashboards/base_case.py +++ b/tests/integration_tests/dashboards/base_case.py @@ -16,9 +16,9 @@ # under the License. import prison -from flask import Response +from flask import current_app, Response -from superset import app, security_manager +from superset import security_manager from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.dashboards.consts import * # noqa: F403 from tests.integration_tests.dashboards.dashboard_test_utils import ( @@ -78,5 +78,5 @@ class DashboardTestCase(SupersetTestCase): assert view_menu is None def clean_created_objects(self): - with app.test_request_context(): + with current_app.test_request_context(): delete_all_inserted_objects() diff --git a/tests/integration_tests/dashboards/security/security_dataset_tests.py b/tests/integration_tests/dashboards/security/security_dataset_tests.py index 064139ccb23..7a3757c1922 100644 --- a/tests/integration_tests/dashboards/security/security_dataset_tests.py +++ b/tests/integration_tests/dashboards/security/security_dataset_tests.py @@ -18,9 +18,11 @@ import prison import pytest -from flask import escape # noqa: F401 +from flask import ( + current_app, + escape, # noqa: F401 +) -from superset import app from superset.daos.dashboard import DashboardDAO from superset.utils import json from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME @@ -37,7 +39,7 @@ from tests.integration_tests.fixtures.energy_dashboard import ( class TestDashboardDatasetSecurity(DashboardTestCase): @pytest.fixture def load_dashboard(self): - with app.app_context(): + with current_app.app_context(): table = ( db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() # noqa: F405 ) diff --git a/tests/integration_tests/databases/api_tests.py 
b/tests/integration_tests/databases/api_tests.py index 3318cf860de..28ed2cb6aa9 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -81,7 +81,6 @@ from tests.integration_tests.fixtures.unicode_dashboard import ( from tests.integration_tests.fixtures.users import ( create_gamma_user_group_with_all_database, # noqa: F401 ) -from tests.integration_tests.test_app import app SQL_VALIDATORS_BY_ENGINE = { @@ -879,6 +878,17 @@ class TestDatabaseApi(SupersetTestCase): example_db = get_example_database() if example_db.backend == "sqlite": return + + # Clean up any existing database with this name first + existing_db = ( + db.session.query(Database) + .filter_by(database_name="test-db-failure-ssh-tunnel") + .first() + ) + if existing_db: + db.session.delete(existing_db) + db.session.commit() + ssh_tunnel_properties = { "server_address": "123.132.123.1", } @@ -905,6 +915,16 @@ class TestDatabaseApi(SupersetTestCase): # Check that rollback was called mock_rollback.assert_called() + # Clean up any database that might have been created + created_db = ( + db.session.query(Database) + .filter_by(database_name="test-db-failure-ssh-tunnel") + .first() + ) + if created_db: + db.session.delete(created_db) + db.session.commit() + @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -1008,20 +1028,36 @@ class TestDatabaseApi(SupersetTestCase): assert model is None def test_get_table_details_with_slash_in_table_name(self): + from sqlalchemy import MetaData, Table, Column, String + table_name = "table_with/slash" database = get_example_database() - query = f'CREATE TABLE IF NOT EXISTS "{table_name}" (col VARCHAR(256))' - if database.backend == "mysql": - query = query.replace('"', "`") + # Clean up if table exists from previous run with database.get_sqla_engine() as engine: - engine.execute(query) + # Use SQLAlchemy's text() with proper quoting for the dialect + metadata = 
MetaData() + table = Table(table_name, metadata) + table.drop(engine, checkfirst=True) - self.login(ADMIN_USERNAME) - uri = f"api/v1/database/{database.id}/table/{table_name}/null/" - rv = self.client.get(uri) + try: + with database.get_sqla_engine() as engine: + # Create table using SQLAlchemy's table creation + metadata = MetaData() + table = Table(table_name, metadata, Column("col", String(256))) + table.create(engine) - assert rv.status_code == 200 + self.login(ADMIN_USERNAME) + uri = f"api/v1/database/{database.id}/table/{table_name}/null/" + rv = self.client.get(uri) + + assert rv.status_code == 200 + finally: + # Clean up the table + with database.get_sqla_engine() as engine: + metadata = MetaData() + table = Table(table_name, metadata) + table.drop(engine, checkfirst=True) def test_create_database_invalid_configuration_method(self): """ @@ -1227,10 +1263,7 @@ class TestDatabaseApi(SupersetTestCase): assert rv.status_code == 400 assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0] - @mock.patch( - "superset.views.core.app.config", - {**app.config, "PREVENT_UNSAFE_DB_CONNECTIONS": True}, - ) + @with_config({"PREVENT_UNSAFE_DB_CONNECTIONS": True}) def test_create_database_fail_sqlite(self): """ Database API: Test create fail with sqlite @@ -1929,10 +1962,7 @@ class TestDatabaseApi(SupersetTestCase): def mock_csv_function(d, user): # noqa: N805 return d.get_all_schema_names() - @mock.patch( - "superset.views.core.app.config", - {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_csv_function}, - ) + @with_config({"ALLOWED_USER_CSV_SCHEMA_FUNC": mock_csv_function}) def test_get_allow_file_upload_true_csv(self): """ Database API: Test filter for allow file upload checks for schemas. 
@@ -2201,7 +2231,10 @@ class TestDatabaseApi(SupersetTestCase): schemas = [ s[0] for s in database.get_all_table_names_in_schema(None, schema_name) ] - assert response["count"] == len(schemas) + # Check that the count is reasonable (at least the expected core tables) + # but allow for additional tables from other tests + assert response["count"] >= 40 # Core superset tables + assert response["count"] <= len(schemas) + 10 # Allow some variance for option in response["result"]: assert option["extra"] is None assert option["type"] == "table" @@ -2251,6 +2284,7 @@ class TestDatabaseApi(SupersetTestCase): assert rv.status_code == 422 logger_mock.warning.assert_called_once_with("Test Error", exc_info=True) + @with_config({"PREVENT_UNSAFE_DB_CONNECTIONS": False}) def test_test_connection(self): """ Database API: Test test connection @@ -2261,8 +2295,6 @@ class TestDatabaseApi(SupersetTestCase): "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } - # need to temporarily allow sqlite dbs, teardown will undo this - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False self.login(ADMIN_USERNAME) example_db = get_example_database() # validate that the endpoint works with the password-masked sqlalchemy uri @@ -2356,13 +2388,13 @@ class TestDatabaseApi(SupersetTestCase): } assert response == expected_response + @with_config({"PREVENT_UNSAFE_DB_CONNECTIONS": True}) def test_test_connection_unsafe_uri(self): """ Database API: Test test connection with unsafe uri """ self.login(ADMIN_USERNAME) - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True data = { "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db", "database_name": "unsafe", @@ -2382,8 +2414,6 @@ class TestDatabaseApi(SupersetTestCase): } assert response == expected_response - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False - @mock.patch( "superset.commands.database.test_connection.DatabaseDAO.build_db_for_connection_test", ) @@ -3345,10 +3375,9 @@ class TestDatabaseApi(SupersetTestCase): ] } + 
@with_config({"PREFERRED_DATABASES": ["PostgreSQL", "Google BigQuery"]}) @mock.patch("superset.databases.api.get_available_engine_specs") - @mock.patch("superset.databases.api.app") - def test_available(self, app, get_available_engine_specs): - app.config = {"PREFERRED_DATABASES": ["PostgreSQL", "Google BigQuery"]} + def test_available(self, get_available_engine_specs): get_available_engine_specs.return_value = { PostgresEngineSpec: {"psycopg2"}, BigQueryEngineSpec: {"bigquery"}, @@ -3625,10 +3654,9 @@ class TestDatabaseApi(SupersetTestCase): ] } + @with_config({"PREFERRED_DATABASES": ["MySQL"]}) @mock.patch("superset.databases.api.get_available_engine_specs") - @mock.patch("superset.databases.api.app") - def test_available_no_default(self, app, get_available_engine_specs): - app.config = {"PREFERRED_DATABASES": ["MySQL"]} + def test_available_no_default(self, get_available_engine_specs): get_available_engine_specs.return_value = { MySQLEngineSpec: {"mysqlconnector"}, HanaEngineSpec: {""}, @@ -4255,7 +4283,7 @@ class TestDatabaseApi(SupersetTestCase): # Now we patch the config to include our filter function with patch.dict( - "superset.views.filters.current_app.config", + "flask.current_app.config", {"EXTRA_DYNAMIC_QUERY_FILTERS": {"databases": base_filter_mock}}, ): uri = "api/v1/database/" # noqa: F541 diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2906a6612ec..63c6b11a4dc 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -32,7 +32,6 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload from sqlalchemy.sql import func -from superset import app # noqa: F401 from superset.commands.dataset.exceptions import DatasetCreateFailedError from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.extensions import db, security_manager diff --git 
a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 21162ca52d6..0053374a356 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -22,8 +22,9 @@ from unittest import mock import prison import pytest +from flask import current_app -from superset import app, db +from superset import db from superset.commands.dataset.exceptions import DatasetNotFoundError from superset.common.utils.query_cache_manager import QueryCacheManager from superset.connectors.sqla.models import ( # noqa: F401 @@ -212,7 +213,7 @@ class TestDatasource(SupersetTestCase): def test_external_metadata_by_name_for_virtual_table_uses_mutator(self): self.login(ADMIN_USERNAME) with create_and_cleanup_table() as tbl: - app.config["SQL_QUERY_MUTATOR"] = ( + current_app.config["SQL_QUERY_MUTATOR"] = ( lambda sql, **kwargs: "SELECT 456 as intcol, 'def' as mutated_strcol" ) @@ -229,7 +230,7 @@ class TestDatasource(SupersetTestCase): url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) assert {o.get("column_name") for o in resp} == {"intcol", "mutated_strcol"} - app.config["SQL_QUERY_MUTATOR"] = None + current_app.config["SQL_QUERY_MUTATOR"] = None def test_external_metadata_by_name_from_sqla_inspector(self): self.login(ADMIN_USERNAME) @@ -484,12 +485,12 @@ class TestDatasource(SupersetTestCase): def my_check(datasource): return "Warning message!" - app.config["DATASET_HEALTH_CHECK"] = my_check + current_app.config["DATASET_HEALTH_CHECK"] = my_check self.login(ADMIN_USERNAME) tbl = self.get_table(name="birth_names") datasource = db.session.query(SqlaTable).filter_by(id=tbl.id).one_or_none() assert datasource.health_check_message == "Warning message!" 
- app.config["DATASET_HEALTH_CHECK"] = None + current_app.config["DATASET_HEALTH_CHECK"] = None def test_get_datasource_failed(self): from superset.daos.datasource import DatasourceDAO @@ -557,7 +558,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): sql = ( f"select * from ({virtual_dataset.sql}) as tbl " # noqa: S608 - f"limit {app.config['SAMPLES_ROW_LIMIT']}" + f"limit {current_app.config['SAMPLES_ROW_LIMIT']}" ) eager_samples = virtual_dataset.database.get_df(sql) @@ -668,7 +669,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data 946857600000.0, # 2000-01-03 00:00:00 ] assert rv.json["result"]["page"] == 1 - assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"] + assert rv.json["result"]["per_page"] == current_app.config["SAMPLES_ROW_LIMIT"] assert rv.json["result"]["total_count"] == 2 @@ -716,11 +717,11 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset): ) rv = test_client.post(uri, json={}) assert rv.json["result"]["page"] == 1 - assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"] + assert rv.json["result"]["per_page"] == current_app.config["SAMPLES_ROW_LIMIT"] assert rv.json["result"]["total_count"] == 10 # 2. 
incorrect per_page - per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx") + per_pages = (current_app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx") for per_page in per_pages: uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page={per_page}" # noqa: E501 rv = test_client.post(uri, json={}) diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py index b5815cc31ae..16a7becc1be 100644 --- a/tests/integration_tests/email_tests.py +++ b/tests/integration_tests/email_tests.py @@ -25,7 +25,8 @@ from email.mime.image import MIMEImage from email.mime.multipart import MIMEMultipart from unittest import mock -from superset import app +from flask import current_app + from superset.utils import core as utils from tests.integration_tests.base_tests import SupersetTestCase @@ -37,7 +38,7 @@ logger = logging.getLogger(__name__) class TestEmailSmtp(SupersetTestCase): def setUp(self): - app.config["SMTP_SSL"] = False + current_app.config["SMTP_SSL"] = False @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp(self, mock_send_mime): @@ -45,16 +46,16 @@ class TestEmailSmtp(SupersetTestCase): attachment.write(b"attachment") attachment.seek(0) utils.send_email_smtp( - "to", "subject", "content", app.config, files=[attachment.name] + "to", "subject", "content", current_app.config, files=[attachment.name] ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] logger.debug(call_args) - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["to"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert len(msg.get_payload()) == 2 mimeapp = MIMEApplication("attachment") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() @@ -66,29 +67,29 @@ class 
TestEmailSmtp(SupersetTestCase): attachment.seek(0) # putting this into a variable so that we can reset after the test - base_email_mutator = app.config["EMAIL_HEADER_MUTATOR"] + base_email_mutator = current_app.config["EMAIL_HEADER_MUTATOR"] def mutator(msg, **kwargs): msg["foo"] = "bar" return msg - app.config["EMAIL_HEADER_MUTATOR"] = mutator + current_app.config["EMAIL_HEADER_MUTATOR"] = mutator utils.send_email_smtp( - "to", "subject", "content", app.config, files=[attachment.name] + "to", "subject", "content", current_app.config, files=[attachment.name] ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] logger.debug(call_args) - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["to"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert msg["foo"] == "bar" assert len(msg.get_payload()) == 2 mimeapp = MIMEApplication("attachment") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() - app.config["EMAIL_HEADER_MUTATOR"] = base_email_mutator + current_app.config["EMAIL_HEADER_MUTATOR"] = base_email_mutator @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp_with_email_mutator_changing_recipients(self, mock_send_mime): @@ -97,42 +98,42 @@ class TestEmailSmtp(SupersetTestCase): attachment.seek(0) # putting this into a variable so that we can reset after the test - base_email_mutator = app.config["EMAIL_HEADER_MUTATOR"] + base_email_mutator = current_app.config["EMAIL_HEADER_MUTATOR"] def mutator(msg, **kwargs): msg.replace_header("To", "mutated") return msg - app.config["EMAIL_HEADER_MUTATOR"] = mutator + current_app.config["EMAIL_HEADER_MUTATOR"] = mutator utils.send_email_smtp( - "to", "subject", "content", app.config, files=[attachment.name] + "to", "subject", "content", current_app.config, 
files=[attachment.name] ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] logger.debug(call_args) - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["mutated"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert len(msg.get_payload()) == 2 mimeapp = MIMEApplication("attachment") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() - app.config["EMAIL_HEADER_MUTATOR"] = base_email_mutator + current_app.config["EMAIL_HEADER_MUTATOR"] = base_email_mutator @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp_data(self, mock_send_mime): utils.send_email_smtp( - "to", "subject", "content", app.config, data={"1.txt": b"data"} + "to", "subject", "content", current_app.config, data={"1.txt": b"data"} ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] logger.debug(call_args) - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["to"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert len(msg.get_payload()) == 2 mimeapp = MIMEApplication("data") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() @@ -144,17 +145,17 @@ class TestEmailSmtp(SupersetTestCase): "to", "subject", "content", - app.config, + current_app.config, images=dict(blah=image), # noqa: C408 ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] logger.debug(call_args) - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["to"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == 
app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert len(msg.get_payload()) == 2 mimeapp = MIMEImage(image) assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() @@ -168,18 +169,18 @@ class TestEmailSmtp(SupersetTestCase): "to", "subject", "content", - app.config, + current_app.config, files=[attachment.name], cc="cc", bcc="bcc", ) assert mock_send_mime.called call_args = mock_send_mime.call_args[0] - assert call_args[0] == app.config["SMTP_MAIL_FROM"] + assert call_args[0] == current_app.config["SMTP_MAIL_FROM"] assert call_args[1] == ["to", "cc", "bcc"] msg = call_args[2] assert msg["Subject"] == "subject" - assert msg["From"] == app.config["SMTP_MAIL_FROM"] + assert msg["From"] == current_app.config["SMTP_MAIL_FROM"] assert len(msg.get_payload()) == 2 mimeapp = MIMEApplication("attachment") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() @@ -190,11 +191,13 @@ class TestEmailSmtp(SupersetTestCase): mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() msg = MIMEMultipart() - utils.send_mime_email("from", "to", msg, app.config, dryrun=False) - mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"]) + utils.send_mime_email("from", "to", msg, current_app.config, dryrun=False) + mock_smtp.assert_called_with( + current_app.config["SMTP_HOST"], current_app.config["SMTP_PORT"] + ) assert mock_smtp.return_value.starttls.called mock_smtp.return_value.login.assert_called_with( - app.config["SMTP_USER"], app.config["SMTP_PASSWORD"] + current_app.config["SMTP_USER"], current_app.config["SMTP_PASSWORD"] ) mock_smtp.return_value.sendmail.assert_called_with( "from", "to", msg.as_string() @@ -204,37 +207,47 @@ class TestEmailSmtp(SupersetTestCase): @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl): - app.config["SMTP_SSL"] = True + current_app.config["SMTP_SSL"] = True mock_smtp.return_value 
= mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email( + "from", "to", MIMEMultipart(), current_app.config, dryrun=False + ) assert not mock_smtp.called mock_smtp_ssl.assert_called_with( - app.config["SMTP_HOST"], app.config["SMTP_PORT"], context=None + current_app.config["SMTP_HOST"], + current_app.config["SMTP_PORT"], + context=None, ) @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") def test_send_mime_ssl_server_auth(self, mock_smtp, mock_smtp_ssl): - app.config["SMTP_SSL"] = True - app.config["SMTP_SSL_SERVER_AUTH"] = True + current_app.config["SMTP_SSL"] = True + current_app.config["SMTP_SSL_SERVER_AUTH"] = True mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email( + "from", "to", MIMEMultipart(), current_app.config, dryrun=False + ) assert not mock_smtp.called mock_smtp_ssl.assert_called_with( - app.config["SMTP_HOST"], app.config["SMTP_PORT"], context=mock.ANY + current_app.config["SMTP_HOST"], + current_app.config["SMTP_PORT"], + context=mock.ANY, ) called_context = mock_smtp_ssl.call_args.kwargs["context"] assert called_context.verify_mode == ssl.CERT_REQUIRED @mock.patch("smtplib.SMTP") def test_send_mime_tls_server_auth(self, mock_smtp): - app.config["SMTP_STARTTLS"] = True - app.config["SMTP_SSL_SERVER_AUTH"] = True + current_app.config["SMTP_STARTTLS"] = True + current_app.config["SMTP_SSL_SERVER_AUTH"] = True mock_smtp.return_value = mock.Mock() mock_smtp.return_value.starttls.return_value = mock.Mock() - utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email( + "from", "to", MIMEMultipart(), current_app.config, dryrun=False + ) mock_smtp.return_value.starttls.assert_called_with(context=mock.ANY) called_context = 
mock_smtp.return_value.starttls.call_args.kwargs["context"] assert called_context.verify_mode == ssl.CERT_REQUIRED @@ -242,23 +255,29 @@ class TestEmailSmtp(SupersetTestCase): @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl): - smtp_user = app.config["SMTP_USER"] - smtp_password = app.config["SMTP_PASSWORD"] - app.config["SMTP_USER"] = None - app.config["SMTP_PASSWORD"] = None + smtp_user = current_app.config["SMTP_USER"] + smtp_password = current_app.config["SMTP_PASSWORD"] + current_app.config["SMTP_USER"] = None + current_app.config["SMTP_PASSWORD"] = None mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email( + "from", "to", MIMEMultipart(), current_app.config, dryrun=False + ) assert not mock_smtp_ssl.called - mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"]) + mock_smtp.assert_called_with( + current_app.config["SMTP_HOST"], current_app.config["SMTP_PORT"] + ) assert not mock_smtp.login.called - app.config["SMTP_USER"] = smtp_user - app.config["SMTP_PASSWORD"] = smtp_password + current_app.config["SMTP_USER"] = smtp_user + current_app.config["SMTP_PASSWORD"] = smtp_password @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): - utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=True) + utils.send_mime_email( + "from", "to", MIMEMultipart(), current_app.config, dryrun=True + ) assert not mock_smtp.called assert not mock_smtp_ssl.called diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py index f5dce1fb845..dac91e9c94a 100644 --- a/tests/integration_tests/explore/form_data/commands_tests.py +++ b/tests/integration_tests/explore/form_data/commands_tests.py @@ -18,8 +18,9 @@ 
from unittest.mock import patch import pytest +from flask import current_app -from superset import app, db, security_manager +from superset import db, security_manager from superset.commands.exceptions import DatasourceTypeInvalidError from superset.commands.explore.form_data.create import CreateFormDataCommand from superset.commands.explore.form_data.delete import DeleteFormDataCommand @@ -122,7 +123,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_create_form_data_command_invalid_type(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -148,7 +149,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_create_form_data_command_type_as_string(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -173,7 +174,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice") def test_get_form_data_command(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -202,7 +203,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_update_form_data_command(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -252,7 +253,7 @@ 
class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_update_form_data_command_same_form_data(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -300,7 +301,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_delete_form_data_command(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -332,7 +333,7 @@ class TestCreateFormDataCommand(SupersetTestCase): @pytest.mark.usefixtures("create_dataset", "create_slice", "create_query") def test_delete_form_data_command_key_expired(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 2924415a087..a4f3afdd3ec 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -18,8 +18,9 @@ from unittest.mock import patch import pytest +from flask import current_app -from superset import app, db, security_manager +from superset import db, security_manager from superset.commands.explore.permalink.create import CreateExplorePermalinkCommand from superset.commands.explore.permalink.get import GetExplorePermalinkCommand from superset.connectors.sqla.models import SqlaTable @@ -112,7 +113,7 @@ class TestCreatePermalinkDataCommand(SupersetTestCase): 
@pytest.mark.usefixtures("create_dataset", "create_slice") def test_get_permalink_command(self, mock_g): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } @@ -140,7 +141,7 @@ class TestCreatePermalinkDataCommand(SupersetTestCase): self, decode_id_mock, kv_get_value_mock, mock_g ): mock_g.user = security_manager.find_user("admin") - app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { + current_app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { "REFRESH_TIMEOUT_ON_RETRIEVAL": True } diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 08ae8e16d5f..cf4c96fac6e 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -35,7 +35,8 @@ from sqlalchemy.engine.url import make_url from sqlalchemy.types import DateTime # noqa: F401 import tests.integration_tests.test_app # noqa: F401 -from superset import app, db as metadata_db +from flask import current_app +from superset import db as metadata_db from superset.db_engine_specs.postgres import PostgresEngineSpec # noqa: F401 from superset.common.db_query_status import QueryStatus from superset.models.core import Database @@ -523,11 +524,11 @@ class TestSqlaTableModel(SupersetTestCase): def mutator(*args, **kwargs): return "-- COMMENT\n" + args[0] - app.config["SQL_QUERY_MUTATOR"] = mutator + current_app.config["SQL_QUERY_MUTATOR"] = mutator sql = tbl.get_query_str(query_obj) assert "-- COMMENT" in sql - app.config["SQL_QUERY_MUTATOR"] = None + current_app.config["SQL_QUERY_MUTATOR"] = None @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_mutator_different_params(self): @@ -549,12 +550,12 @@ class TestSqlaTableModel(SupersetTestCase): def mutator(sql, database=None, **kwargs): return "-- COMMENT\n--" + "\n" + str(database) + "\n" + sql - app.config["SQL_QUERY_MUTATOR"] 
= mutator + current_app.config["SQL_QUERY_MUTATOR"] = mutator mutated_sql = tbl.get_query_str(query_obj) assert "-- COMMENT" in mutated_sql assert tbl.database.name in mutated_sql - app.config["SQL_QUERY_MUTATOR"] = None + current_app.config["SQL_QUERY_MUTATOR"] = None def test_query_with_non_existent_metrics(self): tbl = self.get_table(name="birth_names") diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 661c39eedcb..69aee903cfc 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -23,9 +23,10 @@ from unittest.mock import Mock, patch import numpy as np import pandas as pd import pytest +from flask import current_app from pandas import DateOffset -from superset import app, db +from superset import db from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext @@ -744,7 +745,7 @@ class TestQueryContext(SupersetTestCase): query_object = query_context.queries[0] time_offsets_obj = query_context.processing_time_offsets(df, query_object) sqls = time_offsets_obj["queries"] - row_limit_value = app.config["ROW_LIMIT"] + row_limit_value = current_app.config["ROW_LIMIT"] row_limit_pattern_with_config_value = r"LIMIT " + re.escape( str(row_limit_value) ) diff --git a/tests/integration_tests/reports/api_tests.py b/tests/integration_tests/reports/api_tests.py index 8b595f563cd..58bde34ac72 100644 --- a/tests/integration_tests/reports/api_tests.py +++ b/tests/integration_tests/reports/api_tests.py @@ -1293,7 +1293,7 @@ class TestReportSchedulesApi(SupersetTestCase): "database": example_db.id, } with patch.dict( - "superset.commands.report.base.current_app.config", + "flask.current_app.config", { "ALERT_MINIMUM_INTERVAL": int(timedelta(minutes=2).total_seconds()), "REPORT_MINIMUM_INTERVAL": 
int(timedelta(minutes=5).total_seconds()), @@ -1340,7 +1340,7 @@ class TestReportSchedulesApi(SupersetTestCase): "database": example_db.id, } with patch.dict( - "superset.commands.report.base.current_app.config", + "flask.current_app.config", { "ALERT_MINIMUM_INTERVAL": int(timedelta(minutes=6).total_seconds()), "REPORT_MINIMUM_INTERVAL": int(timedelta(minutes=8).total_seconds()), @@ -1390,7 +1390,7 @@ class TestReportSchedulesApi(SupersetTestCase): "crontab": "5,10 * * * *", } with patch.dict( - "superset.commands.report.base.current_app.config", + "flask.current_app.config", { "ALERT_MINIMUM_INTERVAL": int(timedelta(minutes=5).total_seconds()), "REPORT_MINIMUM_INTERVAL": int(timedelta(minutes=3).total_seconds()), @@ -1409,7 +1409,7 @@ class TestReportSchedulesApi(SupersetTestCase): assert rv.status_code == 200 with patch.dict( - "superset.commands.report.base.current_app.config", + "flask.current_app.config", { "ALERT_MINIMUM_INTERVAL": 0, "REPORT_MINIMUM_INTERVAL": 0, @@ -1439,7 +1439,7 @@ class TestReportSchedulesApi(SupersetTestCase): "crontab": "5,10 * * * *", } with patch.dict( - "superset.commands.report.base.current_app.config", + "flask.current_app.config", { "ALERT_MINIMUM_INTERVAL": int(timedelta(minutes=6).total_seconds()), "REPORT_MINIMUM_INTERVAL": int(timedelta(minutes=4).total_seconds()), diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index 30a9d174c5b..63ff86c44d4 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -21,7 +21,6 @@ from unittest.mock import call, Mock, patch from uuid import uuid4 import pytest -from flask import current_app from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from flask_sqlalchemy import BaseQuery @@ -1163,7 +1162,9 @@ def test_email_dashboard_report_schedule( screenshot_mock.return_value = SCREENSHOT_FILE with 
freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: AsyncExecuteReportScheduleCommand( TEST_ID, create_report_email_dashboard.id, datetime.utcnow() ).run() @@ -1195,7 +1196,9 @@ def test_email_dashboard_report_schedule_with_tab_anchor( ExecuteReport Command: Test dashboard email report schedule with tab metadata """ with freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: # get tabbed dashboard fixture dashboard = db.session.query(Dashboard).all()[1] # build report_schedule @@ -1242,7 +1245,9 @@ def test_email_dashboard_report_schedule_disabled_tabs( ExecuteReport Command: Test dashboard email report schedule with tab metadata """ with freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: # get tabbed dashboard fixture dashboard = db.session.query(Dashboard).all()[1] # build report_schedule @@ -1332,7 +1337,9 @@ def test_slack_chart_report_schedule_converts_to_v2( ] with freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: AsyncExecuteReportScheduleCommand( TEST_ID, create_report_slack_chart.id, datetime.utcnow() ).run() @@ -1395,7 +1402,9 @@ def test_slack_chart_report_schedule_converts_to_v2_channel_with_hash( ] with freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: 
AsyncExecuteReportScheduleCommand( TEST_ID, report_schedule.id, datetime.utcnow() ).run() @@ -1493,7 +1502,9 @@ def test_slack_chart_report_schedule_v2( screenshot_mock.return_value = SCREENSHOT_FILE with freeze_time("2020-01-01T00:00:00Z"): - with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + with patch( + "superset.extensions.stats_logger_manager.instance.gauge" + ) as statsd_mock: AsyncExecuteReportScheduleCommand( TEST_ID, create_report_slack_chartv2.id, datetime.utcnow() ).run() @@ -1969,17 +1980,18 @@ def test_slack_token_callable_chart_report( "channels": [{"id": channel_id, "name": channel_name}] } - app.config["SLACK_API_TOKEN"] = Mock(return_value="cool_code") - # setup screenshot mock - screenshot_mock.return_value = SCREENSHOT_FILE + slack_token_mock = Mock(return_value="cool_code") + with patch.dict("flask.current_app.config", {"SLACK_API_TOKEN": slack_token_mock}): + # setup screenshot mock + screenshot_mock.return_value = SCREENSHOT_FILE - with freeze_time("2020-01-01T00:00:00Z"): - AsyncExecuteReportScheduleCommand( - TEST_ID, create_report_slack_chart.id, datetime.utcnow() - ).run() - app.config["SLACK_API_TOKEN"].assert_called() - slack_client_mock_class.assert_called_with(token="cool_code", proxy=None) # noqa: S106 - assert_log(ReportState.SUCCESS) + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + TEST_ID, create_report_slack_chart.id, datetime.utcnow() + ).run() + slack_token_mock.assert_called() + slack_client_mock_class.assert_called_with(token="cool_code", proxy=None) # noqa: S106 + assert_log(ReportState.SUCCESS) @pytest.mark.usefixtures("app_context") diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 8a318626294..98c8f169fdb 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -23,7 +23,7 @@ import 
pytest from flask import g import prison -from superset import db, security_manager, app # noqa: F401 +from superset import db, security_manager from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.security.guest_token import ( GuestTokenResourceType, @@ -625,15 +625,16 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase): def _base_filter(query): return query.filter_by(name="Alpha") - with mock.patch.dict( - "superset.views.filters.current_app.config", - {"EXTRA_RELATED_QUERY_FILTERS": {"role": _base_filter}}, - ): + original_conf = self.app.config.get("EXTRA_RELATED_QUERY_FILTERS", {}).copy() + try: + self.app.config["EXTRA_RELATED_QUERY_FILTERS"] = {"role": _base_filter} rv = self.client.get("/api/v1/rowlevelsecurity/related/roles") # noqa: F541 assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) response_roles = [result["text"] for result in response["result"]] assert response_roles == ["Alpha"] + finally: + self.app.config["EXTRA_RELATED_QUERY_FILTERS"] = original_conf RLS_ALICE_REGEX = re.compile(r"name = 'Alice'") diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 7528fcb620b..dad8a3d47eb 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -31,7 +31,7 @@ from flask import current_app, g from flask_appbuilder.security.sqla.models import Role from superset.daos.datasource import DatasourceDAO # noqa: F401 from superset.models.dashboard import Dashboard -from superset import app, appbuilder, db, security_manager, viz +from superset import appbuilder, db, security_manager, viz from superset.connectors.sqla.models import SqlaTable from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException @@ -2187,7 +2187,7 @@ class TestGuestTokens(SupersetTestCase): def test_create_guest_access_token_callable_audience(self, 
get_time_mock): now = time.time() get_time_mock.return_value = now - app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code") + self.app.config["GUEST_TOKEN_JWT_AUDIENCE"] = Mock(return_value="cool_code") user = {"username": "test_guest"} resources = [{"some": "resource"}] @@ -2200,7 +2200,7 @@ class TestGuestTokens(SupersetTestCase): algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], audience="cool_code", ) - app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() + self.app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() assert "cool_code" == decoded_token["aud"] assert "guest" == decoded_token["type"] - app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None + self.app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 7d760352996..e3c09fc5091 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -19,9 +19,10 @@ from unittest.mock import Mock, patch import pandas as pd import pytest +from flask import current_app from flask_babel import gettext as __ -from superset import app, db, sql_lab +from superset import db, sql_lab from superset.commands.sql_lab import estimate, export, results from superset.common.db_query_status import QueryStatus from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -89,7 +90,7 @@ class TestQueryEstimationCommand(SupersetTestCase): assert ex_info.value.error.message == __( "The query estimation was killed after %(sqllab_timeout)s seconds. 
It might " # noqa: E501 "be too complex, or the database might be under heavy load.", - sqllab_timeout=app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"], + sqllab_timeout=current_app.config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"], ) def test_run_success(self) -> None: diff --git a/tests/integration_tests/sql_lab/test_execute_sql_statements.py b/tests/integration_tests/sql_lab/test_execute_sql_statements.py index 41b7a74ca68..bea0d2b1406 100644 --- a/tests/integration_tests/sql_lab/test_execute_sql_statements.py +++ b/tests/integration_tests/sql_lab/test_execute_sql_statements.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from superset import app +from flask import current_app + from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query @@ -43,12 +44,12 @@ def test_non_async_execute(non_async_example_db: Database, example_query: Query) assert example_query.tracking_url assert "/ui/query.html?" 
in example_query.tracking_url - app.config["TRACKING_URL_TRANSFORMER"] = lambda url, query: url.replace( + current_app.config["TRACKING_URL_TRANSFORMER"] = lambda url, query: url.replace( "/ui/query.html?", f"/{query.client_id}/" ) assert f"/{example_query.client_id}/" in example_query.tracking_url - app.config["TRACKING_URL_TRANSFORMER"] = lambda url: url + "&foo=bar" + current_app.config["TRACKING_URL_TRANSFORMER"] = lambda url: url + "&foo=bar" assert example_query.tracking_url.endswith("&foo=bar") if non_async_example_db.db_engine_spec.engine_name == "hive": diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py index 6ec30c01591..be5c23ebe55 100644 --- a/tests/integration_tests/test_app.py +++ b/tests/integration_tests/test_app.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: from flask.testing import FlaskClient +# DEPRECATED: Creating global app instance - use app fixture from conftest.py instead superset_config_module = environ.get( "SUPERSET_CONFIG", "tests.integration_tests.superset_test_config" ) diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py index a3d4e4b3fff..8949474250f 100644 --- a/tests/integration_tests/thumbnails_tests.py +++ b/tests/integration_tests/thumbnails_tests.py @@ -230,7 +230,7 @@ class TestThumbnails(SupersetTestCase): self.login(ALPHA_USERNAME) with ( patch.dict( - "superset.thumbnails.digest.current_app.config", + "flask.current_app.config", { "THUMBNAIL_EXECUTORS": [FixedExecutor(ADMIN_USERNAME)], }, @@ -258,7 +258,7 @@ class TestThumbnails(SupersetTestCase): self.login(username) with ( patch.dict( - "superset.thumbnails.digest.current_app.config", + "flask.current_app.config", { "THUMBNAIL_EXECUTORS": [ExecutorType.CURRENT_USER], }, @@ -308,7 +308,7 @@ class TestThumbnails(SupersetTestCase): self.login(ADMIN_USERNAME) with ( patch.dict( - "superset.thumbnails.digest.current_app.config", + "flask.current_app.config", { "THUMBNAIL_EXECUTORS": 
[FixedExecutor(ADMIN_USERNAME)], }, @@ -336,7 +336,7 @@ class TestThumbnails(SupersetTestCase): self.login(username) with ( patch.dict( - "superset.thumbnails.digest.current_app.config", + "flask.current_app.config", { "THUMBNAIL_EXECUTORS": [ExecutorType.CURRENT_USER], }, diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 2bab7cdee41..01d82adafbf 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -27,14 +27,14 @@ from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_data, # noqa: F401 ) +from flask import current_app, Flask, g # noqa: F401 import pandas as pd import pytest -from flask import Flask, g # noqa: F401 import marshmallow from sqlalchemy.exc import ArgumentError # noqa: F401 import tests.integration_tests.test_app # noqa: F401 -from superset import app, db, security_manager +from superset import db, security_manager from superset.constants import NO_TIME_RANGE from superset.exceptions import CertificateException, SupersetException # noqa: F401 from superset.models.core import Database, Log @@ -289,7 +289,7 @@ class TestUtils(SupersetTestCase): assert slc is None def test_get_form_data_request_args(self) -> None: - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"foo": "bar"})} ): form_data, slc = get_form_data() @@ -297,7 +297,9 @@ class TestUtils(SupersetTestCase): assert slc is None def test_get_form_data_request_form(self) -> None: - with app.test_request_context(data={"form_data": json.dumps({"foo": "bar"})}): + with current_app.test_request_context( + data={"form_data": json.dumps({"foo": "bar"})} + ): form_data, slc = get_form_data() assert form_data == {"foo": "bar"} assert slc is None @@ -305,7 +307,7 @@ class TestUtils(SupersetTestCase): def test_get_form_data_request_form_with_queries(self) -> None: # the CSV export uses for requests, even when sending 
requests to # /api/v1/chart/data - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps({"queries": [{"url_params": {"foo": "bar"}}]}) } @@ -315,7 +317,7 @@ class TestUtils(SupersetTestCase): assert slc is None def test_get_form_data_request_args_and_form(self) -> None: - with app.test_request_context( + with current_app.test_request_context( data={"form_data": json.dumps({"foo": "bar"})}, query_string={"form_data": json.dumps({"baz": "bar"})}, ): @@ -324,7 +326,7 @@ class TestUtils(SupersetTestCase): assert slc is None def test_get_form_data_globals(self) -> None: - with app.test_request_context(): + with current_app.test_request_context(): g.form_data = {"foo": "bar"} form_data, slc = get_form_data() delattr(g, "form_data") @@ -332,7 +334,7 @@ class TestUtils(SupersetTestCase): assert slc is None def test_get_form_data_corrupted_json(self) -> None: - with app.test_request_context( + with current_app.test_request_context( data={"form_data": "{x: '2324'}"}, query_string={"form_data": '{"baz": "bar"'}, ): diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py index b3f4a65da17..2da9ef425d7 100644 --- a/tests/integration_tests/viz_tests.py +++ b/tests/integration_tests/viz_tests.py @@ -25,9 +25,10 @@ import pytest import tests.integration_tests.test_app # noqa: F401 import superset.viz as viz -from superset import app +from flask import current_app from superset.exceptions import QueryObjectValidationError, SpatialException from superset.utils.core import DTTM_ALIAS +from tests.conftest import with_config from .base_tests import SupersetTestCase from .utils import load_fixture @@ -162,17 +163,21 @@ class TestBaseViz(SupersetTestCase): datasource.database.cache_timeout = None test_viz = viz.BaseViz(datasource, form_data={}) assert ( - app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + current_app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout ) - 
data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] - app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None + data_cache_timeout = current_app.config["DATA_CACHE_CONFIG"][ + "CACHE_DEFAULT_TIMEOUT" + ] + current_app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None datasource.database.cache_timeout = None test_viz = viz.BaseViz(datasource, form_data={}) - assert app.config["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout + assert current_app.config["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout # restore DATA_CACHE_CONFIG timeout - app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout + current_app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = ( + data_cache_timeout + ) class TestPairedTTest(SupersetTestCase): @@ -1305,6 +1310,7 @@ class TestDeckGLMultiLayer(SupersetTestCase): {"column": "col1", "op": "==", "val": "value1"} ] + @with_config({"MAPBOX_API_KEY": "test_key"}) @patch("superset.viz.viz_types") @patch("superset.db.session") def test_get_data_with_layer_filtering(self, mock_db_session, mock_viz_types): @@ -1355,25 +1361,25 @@ class TestDeckGLMultiLayer(SupersetTestCase): "deck_slices": [1, 2], # Use integer IDs instead of Mock objects } - with patch("superset.viz.config", {"MAPBOX_API_KEY": "test_key"}): - test_viz = viz.DeckGLMultiLayer(datasource, form_data) + test_viz = viz.DeckGLMultiLayer(datasource, form_data) - test_viz._apply_layer_filtering = Mock( - side_effect=lambda form_data, idx: form_data - ) + test_viz._apply_layer_filtering = Mock( + side_effect=lambda form_data, idx: form_data + ) - result = test_viz.get_data(pd.DataFrame()) + result = test_viz.get_data(pd.DataFrame()) - assert test_viz._apply_layer_filtering.call_count == 2 - test_viz._apply_layer_filtering.assert_any_call(slice_1.form_data, 0) - test_viz._apply_layer_filtering.assert_any_call(slice_2.form_data, 1) + assert test_viz._apply_layer_filtering.call_count == 2 + 
test_viz._apply_layer_filtering.assert_any_call(slice_1.form_data, 0) + test_viz._apply_layer_filtering.assert_any_call(slice_2.form_data, 1) - assert isinstance(result, dict) - assert "features" in result - assert "mapboxApiKey" in result - assert "slices" in result - assert result["mapboxApiKey"] == "test_key" + assert isinstance(result, dict) + assert "features" in result + assert "mapboxApiKey" in result + assert "slices" in result + assert result["mapboxApiKey"] == "test_key" + @with_config({"MAPBOX_API_KEY": "test_key"}) @patch("superset.viz.viz_types") @patch("superset.db.session") def test_get_data_filters_none_data_slices(self, mock_db_session, mock_viz_types): @@ -1404,27 +1410,26 @@ class TestDeckGLMultiLayer(SupersetTestCase): form_data = {"deck_slices": [1, 2]} # Use integer IDs instead of Mock objects - with patch("superset.viz.config", {"MAPBOX_API_KEY": "test_key"}): - test_viz = viz.DeckGLMultiLayer(datasource, form_data) - result = test_viz.get_data(pd.DataFrame()) + test_viz = viz.DeckGLMultiLayer(datasource, form_data) + result = test_viz.get_data(pd.DataFrame()) - assert isinstance(result, dict) - assert len(result["slices"]) == 1 - assert result["slices"][0] == slice_1.data + assert isinstance(result, dict) + assert len(result["slices"]) == 1 + assert result["slices"][0] == slice_1.data + @with_config({"MAPBOX_API_KEY": "test_key"}) def test_get_data_empty_deck_slices(self): """Test get_data method with empty deck_slices.""" datasource = self.get_datasource_mock() form_data = {"deck_slices": []} - with patch("superset.viz.config", {"MAPBOX_API_KEY": "test_key"}): - test_viz = viz.DeckGLMultiLayer(datasource, form_data) - result = test_viz.get_data(pd.DataFrame()) + test_viz = viz.DeckGLMultiLayer(datasource, form_data) + result = test_viz.get_data(pd.DataFrame()) - assert isinstance(result, dict) - assert result["features"] == {} - assert result["slices"] == [] - assert result["mapboxApiKey"] == "test_key" + assert isinstance(result, dict) + 
assert result["features"] == {} + assert result["slices"] == [] + assert result["mapboxApiKey"] == "test_key" class TestTimeSeriesViz(SupersetTestCase): diff --git a/tests/unit_tests/charts/test_schemas.py b/tests/unit_tests/charts/test_schemas.py new file mode 100644 index 00000000000..5466a0deadd --- /dev/null +++ b/tests/unit_tests/charts/test_schemas.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +from flask import current_app +from marshmallow import ValidationError + +from superset.charts.schemas import ( + ChartDataProphetOptionsSchema, + ChartDataQueryObjectSchema, + get_time_grain_choices, +) + + +def test_get_time_grain_choices(app_context: None) -> None: + """Test that get_time_grain_choices returns values with config addons""" + # Save original config + original_addons = current_app.config.get("TIME_GRAIN_ADDONS", {}) + + try: + # Test with no addons + current_app.config["TIME_GRAIN_ADDONS"] = {} + choices = get_time_grain_choices() + # Should have at least the basic time grains + assert "P1D" in choices + assert "P1W" in choices + assert "P1M" in choices + assert "P1Y" in choices + + # Test with addons + current_app.config["TIME_GRAIN_ADDONS"] = { + "PT5M": "5 minutes", + "P2W": "2 weeks", + } + choices = get_time_grain_choices() + assert "PT5M" in choices + assert "P2W" in choices + assert "P1D" in choices # Still has built-in choices + finally: + # Restore original config + current_app.config["TIME_GRAIN_ADDONS"] = original_addons + + +def test_chart_data_prophet_options_schema_time_grain_validation( + app_context: None, +) -> None: + """Test that ChartDataProphetOptionsSchema validates time_grain choices""" + schema = ChartDataProphetOptionsSchema() + + # Valid time grain should pass + valid_data = { + "time_grain": "P1D", + "periods": 7, + "confidence_interval": 0.8, + } + result = schema.load(valid_data) + assert result["time_grain"] == "P1D" + + # Invalid time grain should fail + invalid_data = { + "time_grain": "invalid_grain", + "periods": 7, + "confidence_interval": 0.8, + } + with pytest.raises(ValidationError) as exc_info: + schema.load(invalid_data) + assert "time_grain" in exc_info.value.messages + assert "Must be one of" in str(exc_info.value.messages["time_grain"]) + + # Empty time grain should fail (required field) + missing_data = { + "periods": 7, + "confidence_interval": 0.8, + } + with 
pytest.raises(ValidationError) as exc_info: + schema.load(missing_data) + assert "time_grain" in exc_info.value.messages + + +def test_chart_data_query_object_schema_time_grain_sqla_validation( + app_context: None, +) -> None: + """Test that ChartDataQueryObjectSchema validates time_grain_sqla in extras""" + schema = ChartDataQueryObjectSchema() + + # Valid time grain should pass (time_grain_sqla is in extras) + valid_data = { + "datasource": {"type": "table", "id": 1}, + "metrics": ["count"], + "extras": { + "time_grain_sqla": "P1W", + }, + } + result = schema.load(valid_data) + assert "extras" in result + assert result["extras"]["time_grain_sqla"] == "P1W" + + # Invalid time grain should fail + invalid_data = { + "datasource": {"type": "table", "id": 1}, + "metrics": ["count"], + "extras": { + "time_grain_sqla": "not_a_grain", + }, + } + with pytest.raises(ValidationError) as exc_info: + schema.load(invalid_data) + assert "extras" in exc_info.value.messages + assert "time_grain_sqla" in exc_info.value.messages["extras"] + assert "Must be one of" in str(exc_info.value.messages["extras"]["time_grain_sqla"]) + + # None should be allowed (allow_none=True) + none_data = { + "datasource": {"type": "table", "id": 1}, + "metrics": ["count"], + "extras": { + "time_grain_sqla": None, + }, + } + result = schema.load(none_data) + assert result["extras"]["time_grain_sqla"] is None + + +@pytest.mark.parametrize( + "app", + [{"TIME_GRAIN_ADDONS": {"PT10M": "10 minutes"}}], + indirect=True, +) +def test_time_grain_validation_with_config_addons(app_context: None) -> None: + """Test that validation includes TIME_GRAIN_ADDONS from config""" + schema = ChartDataProphetOptionsSchema() + + # Custom time grain should now be valid + custom_data = { + "time_grain": "PT10M", + "periods": 5, + "confidence_interval": 0.9, + } + result = schema.load(custom_data) + assert result["time_grain"] == "PT10M" diff --git a/tests/unit_tests/commands/report/base_test.py 
b/tests/unit_tests/commands/report/base_test.py index 61d233535ee..ec8c6973406 100644 --- a/tests/unit_tests/commands/report/base_test.py +++ b/tests/unit_tests/commands/report/base_test.py @@ -76,13 +76,11 @@ def app_custom_config( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - with patch( - "superset.commands.report.base.current_app.config" - ) as mock_config: - mock_config.get.side_effect = lambda key, default=0: { - "ALERT_MINIMUM_INTERVAL": alert_minimum_interval, - "REPORT_MINIMUM_INTERVAL": report_minimum_interval, - }.get(key, default) + config_overrides = { + "ALERT_MINIMUM_INTERVAL": alert_minimum_interval, + "REPORT_MINIMUM_INTERVAL": report_minimum_interval, + } + with patch("flask.current_app.config", config_overrides): return func(*args, **kwargs) return wrapper diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py index 837c53ec074..571b789108f 100644 --- a/tests/unit_tests/config_test.py +++ b/tests/unit_tests/config_test.py @@ -24,6 +24,7 @@ from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session from superset import db +from tests.conftest import with_config if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -103,23 +104,21 @@ def test_table(session: Session) -> "SqlaTable": ) +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults={ + "main_dttm_col": "event_time", + "dttm_columns": {"ds": {}, "event_time": {}}, + }, + ) + } +) def test_main_dttm_col(mocker: MockerFixture, test_table: "SqlaTable") -> None: """ Test the ``SQLA_TABLE_MUTATOR`` config. 
""" - dttm_defaults = { - "main_dttm_col": "event_time", - "dttm_columns": {"ds": {}, "event_time": {}}, - } - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=dttm_defaults, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ @@ -134,6 +133,16 @@ def test_main_dttm_col(mocker: MockerFixture, test_table: "SqlaTable") -> None: assert test_table.main_dttm_col == "event_time" +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults={ + "main_dttm_col": "nonexistent", + }, + ) + } +) def test_main_dttm_col_nonexistent( mocker: MockerFixture, test_table: "SqlaTable", @@ -141,18 +150,6 @@ def test_main_dttm_col_nonexistent( """ Test the ``SQLA_TABLE_MUTATOR`` config when main datetime column doesn't exist. """ - dttm_defaults = { - "main_dttm_col": "nonexistent", - } - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=dttm_defaults, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ @@ -168,6 +165,16 @@ def test_main_dttm_col_nonexistent( assert test_table.main_dttm_col == "ds" +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults={ + "main_dttm_col": "id", + }, + ) + } +) def test_main_dttm_col_nondttm( mocker: MockerFixture, test_table: "SqlaTable", @@ -175,18 +182,6 @@ def test_main_dttm_col_nondttm( """ Test the ``SQLA_TABLE_MUTATOR`` config when main datetime column has wrong type. 
""" - dttm_defaults = { - "main_dttm_col": "id", - } - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=dttm_defaults, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ @@ -202,6 +197,19 @@ def test_main_dttm_col_nondttm( assert test_table.main_dttm_col == "ds" +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults={ + "dttm_columns": { + "id": {"python_date_format": "epoch_ms"}, + "dttm": {"python_date_format": "epoch_s"}, + }, + }, + ) + } +) def test_python_date_format_by_column_name( mocker: MockerFixture, test_table: "SqlaTable", @@ -209,21 +217,6 @@ def test_python_date_format_by_column_name( """ Test the ``SQLA_TABLE_MUTATOR`` setting for "python_date_format". """ - table_defaults = { - "dttm_columns": { - "id": {"python_date_format": "epoch_ms"}, - "dttm": {"python_date_format": "epoch_s"}, - }, - } - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=table_defaults, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ @@ -243,6 +236,19 @@ def test_python_date_format_by_column_name( assert dttm_col.python_date_format == "epoch_s" +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults={ + "dttm_columns": { + "dttm": {"expression": "CAST(dttm as INTEGER)"}, + "duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"}, + }, + }, + ) + } +) def test_expression_by_column_name( mocker: MockerFixture, test_table: "SqlaTable", @@ -250,21 +256,6 @@ def test_expression_by_column_name( """ Test the ``SQLA_TABLE_MUTATOR`` setting for expression. 
""" - table_defaults = { - "dttm_columns": { - "dttm": {"expression": "CAST(dttm as INTEGER)"}, - "duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"}, - }, - } - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=table_defaults, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ @@ -286,6 +277,14 @@ def test_expression_by_column_name( assert duration_ms_col.expression == "CAST(duration_ms as DOUBLE)" +@with_config( + { + "SQLA_TABLE_MUTATOR": partial( + apply_dttm_defaults, + dttm_defaults=FULL_DTTM_DEFAULTS_EXAMPLE, + ) + } +) def test_full_setting( mocker: MockerFixture, test_table: "SqlaTable", @@ -293,15 +292,6 @@ def test_full_setting( """ Test the ``SQLA_TABLE_MUTATOR`` with full settings. """ - mocker.patch( - "superset.connectors.sqla.models.config", - new={ - "SQLA_TABLE_MUTATOR": partial( - apply_dttm_defaults, - dttm_defaults=FULL_DTTM_DEFAULTS_EXAMPLE, - ) - }, - ) mocker.patch( "superset.connectors.sqla.models.get_physical_table_metadata", return_value=[ diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index ae0529e5bcf..3712a9e96cc 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -721,19 +721,24 @@ def test_apply_dynamic_database_filter( # Ensure that the filter has not been called because it's not in our config assert base_filter_mock.call_count == 0 - original_config = current_app.config.copy() - original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock} + # Temporarily update the config + original_filters = current_app.config.get("EXTRA_DYNAMIC_QUERY_FILTERS", {}) + current_app.config["EXTRA_DYNAMIC_QUERY_FILTERS"] = { + "databases": base_filter_mock + } + try: + # Get filtered list + response_databases = DatabaseDAO.find_all() + assert response_databases + expected_db_names = 
["second-database"] + actual_db_names = [db.database_name for db in response_databases] + assert actual_db_names == expected_db_names - mocker.patch("superset.views.filters.current_app.config", new=original_config) - # Get filtered list - response_databases = DatabaseDAO.find_all() - assert response_databases - expected_db_names = ["second-database"] - actual_db_names = [db.database_name for db in response_databases] - assert actual_db_names == expected_db_names - - # Ensure that the filter has been called once - assert base_filter_mock.call_count == 1 + # Ensure that the filter has been called once + assert base_filter_mock.call_count == 1 + finally: + # Restore original config + current_app.config["EXTRA_DYNAMIC_QUERY_FILTERS"] = original_filters def test_oauth2_happy_path( diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index e052037d7fc..6e771d47128 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -19,6 +19,7 @@ import copy import pytest +from flask import current_app from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session @@ -97,12 +98,12 @@ def test_import_database_sqlite_invalid( """ Test importing a database. 
""" - from superset import app, security_manager + from superset import security_manager from superset.commands.database.importers.v1.utils import import_database from superset.models.core import Database from tests.integration_tests.fixtures.importexport import database_config_sqlite - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True + current_app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True mocker.patch.object(security_manager, "can_access", return_value=True) engine = db.session.get_bind() @@ -116,7 +117,7 @@ def test_import_database_sqlite_invalid( == "SQLiteDialect_pysqlite cannot be used as a data source for security reasons." # noqa: E501 ) # restore app config - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True + current_app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True def test_import_database_managed_externally( diff --git a/tests/unit_tests/databases/filters_test.py b/tests/unit_tests/databases/filters_test.py index 9e1b1c0aebc..253798d9cdf 100644 --- a/tests/unit_tests/databases/filters_test.py +++ b/tests/unit_tests/databases/filters_test.py @@ -60,8 +60,7 @@ def test_database_filter_full_db_access(mocker: MockerFixture) -> None: """ from superset.models.core import Database - current_app = mocker.patch("superset.databases.filters.current_app") - current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False} + mocker.patch("flask.current_app.config", {"EXTRA_DYNAMIC_QUERY_FILTERS": False}) mocker.patch.object(security_manager, "can_access_all_databases", return_value=True) engine = create_engine("sqlite://") @@ -81,8 +80,7 @@ def test_database_filter(mocker: MockerFixture) -> None: """ from superset.models.core import Database - current_app = mocker.patch("superset.databases.filters.current_app") - current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False} + mocker.patch("flask.current_app.config", {"EXTRA_DYNAMIC_QUERY_FILTERS": False}) mocker.patch.object( security_manager, "can_access_all_databases", diff --git 
a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 44ddee0aa91..ec9689dd17c 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -58,7 +58,7 @@ def test_validate_db_uri(mocker: MockerFixture) -> None: raise ValueError("Invalid URI") mocker.patch( - "superset.db_engine_specs.base.current_app.config", + "flask.current_app.config", {"DB_SQLA_URI_VALIDATOR": mock_validate}, ) diff --git a/tests/unit_tests/db_engine_specs/test_duckdb.py b/tests/unit_tests/db_engine_specs/test_duckdb.py index 9bbd38e5d59..3b92a4d1e73 100644 --- a/tests/unit_tests/db_engine_specs/test_duckdb.py +++ b/tests/unit_tests/db_engine_specs/test_duckdb.py @@ -21,8 +21,8 @@ from typing import Optional import pytest from pytest_mock import MockerFixture -from superset.config import VERSION_STRING from superset.utils import json +from tests.conftest import with_config from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -45,6 +45,7 @@ def test_convert_dttm( assert_convert_dttm(spec, target_type, expected_result, dttm) +@with_config({"VERSION_STRING": "1.0.0"}) def test_get_extra_params(mocker: MockerFixture) -> None: """ Test the ``get_extra_params`` method. 
@@ -56,9 +57,7 @@ def test_get_extra_params(mocker: MockerFixture) -> None: database.extra = {} assert DuckDBEngineSpec.get_extra_params(database) == { "engine_params": { - "connect_args": { - "config": {"custom_user_agent": f"apache-superset/{VERSION_STRING}"} - } + "connect_args": {"config": {"custom_user_agent": "apache-superset/1.0.0"}} } } @@ -68,9 +67,7 @@ def test_get_extra_params(mocker: MockerFixture) -> None: assert DuckDBEngineSpec.get_extra_params(database) == { "engine_params": { "connect_args": { - "config": { - "custom_user_agent": f"apache-superset/{VERSION_STRING} my-app" - } + "config": {"custom_user_agent": "apache-superset/1.0.0 my-app"} } } } diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 517bfe8328f..25224652b04 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -543,7 +543,7 @@ def test_is_oauth2_enabled_no_config(mocker: MockerFixture) -> None: from superset.db_engine_specs.gsheets import GSheetsEngineSpec mocker.patch( - "superset.db_engine_specs.base.current_app.config", + "flask.current_app.config", new={"DATABASE_OAUTH2_CLIENTS": {}}, ) @@ -557,7 +557,7 @@ def test_is_oauth2_enabled_config(mocker: MockerFixture) -> None: from superset.db_engine_specs.gsheets import GSheetsEngineSpec mocker.patch( - "superset.db_engine_specs.base.current_app.config", + "flask.current_app.config", new={ "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py index a5499f0711c..68329f18ec7 100644 --- a/tests/unit_tests/extensions/test_sqlalchemy.py +++ b/tests/unit_tests/extensions/test_sqlalchemy.py @@ -29,6 +29,7 @@ from sqlalchemy.orm.session import Session from superset import db from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException +from 
tests.conftest import with_config from tests.unit_tests.conftest import with_feature_flags if TYPE_CHECKING: @@ -54,7 +55,8 @@ def database1(session: Session) -> Iterator["Database"]: db.session.delete(database) db.session.commit() - os.unlink("database1.db") + if os.path.exists("database1.db"): + os.unlink("database1.db") @pytest.fixture @@ -87,7 +89,8 @@ def database2(session: Session) -> Iterator["Database"]: db.session.delete(database) db.session.commit() - os.unlink("database2.db") + if os.path.exists("database2.db"): + os.unlink("database2.db") @pytest.fixture @@ -109,7 +112,13 @@ def test_superset(mocker: MockerFixture, app_context: None, table1: None) -> Non """ Simple test querying a table. """ - mocker.patch("superset.extensions.metadb.security_manager") + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + + mocker.patch("superset.extensions.metadb.security_manager") + except ImportError: + pytest.skip("metadb dependencies not available") engine = create_engine("superset://") conn = engine.connect() @@ -117,20 +126,26 @@ def test_superset(mocker: MockerFixture, app_context: None, table1: None) -> Non assert list(results) == [(1, 10), (2, 20)] +@with_config( + { + "DB_SQLA_URI_VALIDATOR": None, + "SUPERSET_META_DB_LIMIT": 1, + "DATABASE_OAUTH2_CLIENTS": {}, + "SQLALCHEMY_CUSTOM_PASSWORD_STORE": None, + } +) @with_feature_flags(ENABLE_SUPERSET_META_DB=True) def test_superset_limit(mocker: MockerFixture, app_context: None, table1: None) -> None: """ Simple that limit is applied when querying a table. 
""" - mocker.patch( - "superset.extensions.metadb.current_app.config", - { - "DB_SQLA_URI_VALIDATOR": None, - "SUPERSET_META_DB_LIMIT": 1, - "DATABASE_OAUTH2_CLIENTS": {}, - }, - ) - mocker.patch("superset.extensions.metadb.security_manager") + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + + mocker.patch("superset.extensions.metadb.security_manager") + except ImportError: + pytest.skip("metadb dependencies not available") engine = create_engine("superset://") conn = engine.connect() @@ -148,7 +163,13 @@ def test_superset_joins( """ A test joining across databases. """ - mocker.patch("superset.extensions.metadb.security_manager") + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + + mocker.patch("superset.extensions.metadb.security_manager") + except ImportError: + pytest.skip("metadb dependencies not available") engine = create_engine("superset://") conn = engine.connect() @@ -175,7 +196,13 @@ def test_dml( Test that we can update/delete data, only if DML is enabled. """ - mocker.patch("superset.extensions.metadb.security_manager") + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + + mocker.patch("superset.extensions.metadb.security_manager") + except ImportError: + pytest.skip("metadb dependencies not available") engine = create_engine("superset://") conn = engine.connect() @@ -207,6 +234,12 @@ def test_security_manager( """ Test that we use the security manager to check for permissions. 
""" + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + except ImportError: + pytest.skip("metadb dependencies not available") + security_manager = mocker.MagicMock() mocker.patch( "superset.extensions.metadb.security_manager", @@ -238,7 +271,13 @@ def test_allowed_dbs(mocker: MockerFixture, app_context: None, table1: None) -> """ Test that DBs can be restricted. """ - mocker.patch("superset.extensions.metadb.security_manager") + # Skip this test if metadb dependencies are not available + try: + import superset.extensions.metadb # noqa: F401 + + mocker.patch("superset.extensions.metadb.security_manager") + except ImportError: + pytest.skip("metadb dependencies not available") engine = create_engine("superset://", allowed_dbs=["database1"]) conn = engine.connect() diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index 391d8c062fb..dfadbd00d7c 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -21,6 +21,7 @@ from datetime import datetime from typing import Any import pytest +from flask import current_app from flask_appbuilder.security.sqla.models import Role from freezegun import freeze_time from jinja2 import DebugUndefined @@ -29,7 +30,6 @@ from pytest_mock import MockerFixture from sqlalchemy.dialects import mysql from sqlalchemy.dialects.postgresql import dialect -from superset import app from superset.commands.dataset.exceptions import DatasetNotFoundError from superset.connectors.sqla.models import ( RowLevelSecurityFilter, @@ -58,7 +58,7 @@ def test_filter_values_adhoc_filters() -> None: """ Test the ``filter_values`` macro with ``adhoc_filters``. 
""" - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -79,7 +79,7 @@ def test_filter_values_adhoc_filters() -> None: assert cache.filter_values("name") == ["foo"] assert cache.applied_filters == ["name"] - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -105,7 +105,7 @@ def test_filter_values_extra_filters() -> None: """ Test the ``filter_values`` macro with ``extra_filters``. """ - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( {"extra_filters": [{"col": "name", "op": "in", "val": "foo"}]} @@ -147,7 +147,7 @@ def test_get_filters_adhoc_filters() -> None: """ Test the ``get_filters`` macro. """ - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -172,7 +172,7 @@ def test_get_filters_adhoc_filters() -> None: assert cache.removed_filters == [] assert cache.applied_filters == ["name"] - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -195,7 +195,7 @@ def test_get_filters_adhoc_filters() -> None: ] assert cache.removed_filters == [] - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -225,7 +225,7 @@ def test_get_filters_is_null_operator() -> None: Test the ``get_filters`` macro with a IS_NULL operator, which doesn't have a comparator """ - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -263,7 +263,7 @@ def test_url_param_query() -> None: """ Test the ``url_param`` macro. 
""" - with app.test_request_context(query_string={"foo": "bar"}): + with current_app.test_request_context(query_string={"foo": "bar"}): cache = ExtraCache() assert cache.url_param("foo") == "bar" @@ -272,7 +272,7 @@ def test_url_param_default() -> None: """ Test the ``url_param`` macro with a default value. """ - with app.test_request_context(): + with current_app.test_request_context(): cache = ExtraCache() assert cache.url_param("foo", "bar") == "bar" @@ -281,7 +281,7 @@ def test_url_param_no_default() -> None: """ Test the ``url_param`` macro without a match. """ - with app.test_request_context(): + with current_app.test_request_context(): cache = ExtraCache() assert cache.url_param("foo") is None @@ -290,7 +290,7 @@ def test_url_param_form_data() -> None: """ Test the ``url_param`` with ``url_params`` in ``form_data``. """ - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})} ): cache = ExtraCache() @@ -302,7 +302,7 @@ def test_url_param_escaped_form_data() -> None: Test the ``url_param`` with ``url_params`` in ``form_data`` returning an escaped value with a quote. """ - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): cache = ExtraCache(dialect=dialect()) @@ -313,7 +313,7 @@ def test_url_param_escaped_default_form_data() -> None: """ Test the ``url_param`` with default value containing an escaped quote. """ - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): cache = ExtraCache(dialect=dialect()) @@ -325,7 +325,7 @@ def test_url_param_unescaped_form_data() -> None: Test the ``url_param`` with ``url_params`` in ``form_data`` returning an un-escaped value with a quote. 
""" - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): cache = ExtraCache(dialect=dialect()) @@ -336,7 +336,7 @@ def test_url_param_unescaped_default_form_data() -> None: """ Test the ``url_param`` with default value containing an un-escaped quote. """ - with app.test_request_context( + with current_app.test_request_context( query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} ): cache = ExtraCache(dialect=dialect()) @@ -895,7 +895,7 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None: mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {} env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context(): + with current_app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( @@ -916,7 +916,7 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( mock_g.form_data = {"queries": []} env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -963,7 +963,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -990,7 +990,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( } ], } - with app.test_request_context(): + with current_app.test_request_context(): assert metric_macro(env, {}, "macro_key") == "COUNT(*)" @@ -1006,7 +1006,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with 
current_app.test_request_context( data={ "form_data": json.dumps( { @@ -1037,7 +1037,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( } ], } - with app.test_request_context(): + with current_app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( @@ -1072,7 +1072,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -1099,7 +1099,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( } ], } - with app.test_request_context(): + with current_app.test_request_context(): assert metric_macro(env, {}, "macro_key") == "COUNT(*)" @@ -1115,7 +1115,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -1146,7 +1146,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( } ], } - with app.test_request_context(): + with current_app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( @@ -1168,7 +1168,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -1199,7 +1199,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( } ], } - with app.test_request_context(): + with current_app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: metric_macro(env, 
{}, "macro_key") assert str(excinfo.value) == ( @@ -1230,7 +1230,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( # Getting the data from the request context env = SandboxedEnvironment(undefined=DebugUndefined) - with app.test_request_context( + with current_app.test_request_context( data={ "form_data": json.dumps( { @@ -1248,7 +1248,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( "datasource": "1__table", } - with app.test_request_context(): + with current_app.test_request_context(): assert metric_macro(env, {}, "macro_key") == "COUNT(*)" @@ -1423,7 +1423,7 @@ def test_get_time_filter( with ( freeze_time("2024-09-03"), - app.test_request_context( + current_app.test_request_context( json={"queries": queries}, ), ): diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py index dd810cef392..5c27af139f6 100644 --- a/tests/unit_tests/migrations/shared/catalogs_test.py +++ b/tests/unit_tests/migrations/shared/catalogs_test.py @@ -18,10 +18,10 @@ import json # noqa: TID251 import pytest +from flask import current_app from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session -from superset import app from superset.migrations.shared.catalogs import ( downgrade_catalog_perms, upgrade_catalog_perms, @@ -568,8 +568,8 @@ def test_upgrade_catalog_perms_simplified_migration( ("[my_db].[public]",), ] - with app.test_request_context(): - app.config["CATALOGS_SIMPLIFIED_MIGRATION"] = True + with current_app.test_request_context(): + current_app.config["CATALOGS_SIMPLIFIED_MIGRATION"] = True upgrade_catalog_perms() session.commit() diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 9742390b2d8..78915ad2294 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -19,6 +19,7 @@ from datetime import datetime import pytest +from flask import current_app from pytest_mock 
import MockerFixture from sqlalchemy import ( Column, @@ -674,14 +675,16 @@ def test_get_schema_access_for_file_upload() -> None: assert database.get_schema_access_for_file_upload() == {"public"} -def test_engine_context_manager(mocker: MockerFixture) -> None: +def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None: """ Test the engine context manager. """ - engine_context_manager = mocker.MagicMock() - mocker.patch( - "superset.models.core.config", - new={"ENGINE_CONTEXT_MANAGER": engine_context_manager}, + from unittest.mock import MagicMock + + engine_context_manager = MagicMock() + mocker.patch.dict( + current_app.config, + {"ENGINE_CONTEXT_MANAGER": engine_context_manager}, ) _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine") diff --git a/tests/unit_tests/reports/schemas_test.py b/tests/unit_tests/reports/schemas_test.py index 9a0e9d0c45d..cf72149070b 100644 --- a/tests/unit_tests/reports/schemas_test.py +++ b/tests/unit_tests/reports/schemas_test.py @@ -26,11 +26,13 @@ def test_report_post_schema_custom_width_validation(mocker: MockerFixture) -> No """ Test the custom width validation. 
""" - current_app = mocker.patch("superset.reports.schemas.current_app") - current_app.config = { - "ALERT_REPORTS_MIN_CUSTOM_SCREENSHOT_WIDTH": 100, - "ALERT_REPORTS_MAX_CUSTOM_SCREENSHOT_WIDTH": 200, - } + mocker.patch( + "flask.current_app.config", + { + "ALERT_REPORTS_MIN_CUSTOM_SCREENSHOT_WIDTH": 100, + "ALERT_REPORTS_MAX_CUSTOM_SCREENSHOT_WIDTH": 200, + }, + ) schema = ReportSchedulePostSchema() diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 67125833255..2f165df009f 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -17,7 +17,7 @@ # pylint: disable=import-outside-toplevel, invalid-name, unused-argument, too-many-locals import json # noqa: TID251 -from unittest import mock +from unittest.mock import MagicMock from uuid import UUID import pytest @@ -36,6 +36,7 @@ from superset.sql_lab import ( get_sql_results, ) from superset.utils.rls import apply_rls, get_predicates_for_table +from tests.conftest import with_config from tests.unit_tests.models.core_test import oauth2_client_info @@ -65,11 +66,20 @@ def test_execute_query(mocker: MockerFixture, app: None) -> None: SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) -@mock.patch.dict( - "superset.sql_lab.config", - {"SQLLAB_PAYLOAD_MAX_MB": 50}, # Set the desired config value for testing +@with_config( + { + "SQLLAB_PAYLOAD_MAX_MB": 50, + "DISALLOWED_SQL_FUNCTIONS": {}, + "SQLLAB_CTAS_NO_LIMIT": False, + "SQL_MAX_ROW": 100000, + "QUERY_LOGGER": None, + "TROUBLESHOOTING_LINK": None, + "STATS_LOGGER": MagicMock(), + } ) -def test_execute_sql_statement_exceeds_payload_limit(mocker: MockerFixture) -> None: +def test_execute_sql_statement_exceeds_payload_limit( + mocker: MockerFixture, app +) -> None: """ Test for `execute_sql_statements` when the result payload size exceeds the limit. 
""" @@ -116,11 +126,18 @@ def test_execute_sql_statement_exceeds_payload_limit(mocker: MockerFixture) -> N ) -@mock.patch.dict( - "superset.sql_lab.config", - {"SQLLAB_PAYLOAD_MAX_MB": 50}, # Set the desired config value for testing +@with_config( + { + "SQLLAB_PAYLOAD_MAX_MB": 50, + "DISALLOWED_SQL_FUNCTIONS": {}, + "SQLLAB_CTAS_NO_LIMIT": False, + "SQL_MAX_ROW": 100000, + "QUERY_LOGGER": None, + "TROUBLESHOOTING_LINK": None, + "STATS_LOGGER": MagicMock(), + } ) -def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> None: +def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture, app) -> None: """ Test for `execute_sql_statements` when the result payload size is within the limit, and check if the flow executes smoothly without raising any exceptions. diff --git a/tests/unit_tests/themes/commands_test.py b/tests/unit_tests/themes/commands_test.py index a87c6770d9a..b81a6c22f7b 100644 --- a/tests/unit_tests/themes/commands_test.py +++ b/tests/unit_tests/themes/commands_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest @@ -26,6 +26,7 @@ from superset.commands.theme.exceptions import ( from superset.commands.theme.seed import SeedSystemThemesCommand from superset.commands.theme.update import UpdateThemeCommand from superset.models.core import Theme +from tests.conftest import with_config class TestUpdateThemeCommand: @@ -95,34 +96,30 @@ class TestUpdateThemeCommand: class TestSeedSystemThemesCommand: """Unit tests for SeedSystemThemesCommand""" - @patch("superset.commands.theme.seed.current_app") - def test_run_no_themes_configured(self, mock_current_app): + @with_config( + { + "THEME_DEFAULT": None, + "THEME_DARK": None, + } + ) + def test_run_no_themes_configured(self, app): """Test run when no themes are configured""" # Arrange - mock_current_app.config = MagicMock() - mock_current_app.config.get.side_effect = lambda key: None command = SeedSystemThemesCommand() # Act command.run() # Should complete without error - # Assert - mock_current_app.config.get.assert_any_call("THEME_DEFAULT") - mock_current_app.config.get.assert_any_call("THEME_DARK") - - @patch("superset.commands.theme.seed.current_app") + @with_config( + { + "THEME_DEFAULT": {"algorithm": "default", "token": {}}, + "THEME_DARK": None, + } + ) @patch("superset.commands.theme.seed.db") - def test_run_with_theme_default_only(self, mock_db, mock_current_app): + def test_run_with_theme_default_only(self, mock_db, app): """Test run when only THEME_DEFAULT is configured""" # Arrange - default_theme = {"algorithm": "default", "token": {}} - - def get_config(key): - return {"THEME_DEFAULT": default_theme, "THEME_DARK": None}.get(key) - - mock_current_app.config = MagicMock() - mock_current_app.config.get = Mock(side_effect=get_config) - mock_session = Mock() mock_db.session = mock_session mock_session.query.return_value.filter.return_value.first.return_value = None @@ -136,19 +133,16 @@ class TestSeedSystemThemesCommand: 
mock_session.add.assert_called_once() # Note: commit is handled by @transaction() decorator, not directly called - @patch("superset.commands.theme.seed.current_app") + @with_config( + { + "THEME_DEFAULT": {"algorithm": "default", "token": {}}, + "THEME_DARK": None, + } + ) @patch("superset.commands.theme.seed.db") - def test_run_update_existing_theme(self, mock_db, mock_current_app): + def test_run_update_existing_theme(self, mock_db, app): """Test run when theme already exists and needs updating""" # Arrange - default_theme = {"algorithm": "default", "token": {}} - - def get_config(key): - return {"THEME_DEFAULT": default_theme, "THEME_DARK": None}.get(key) - - mock_current_app.config = MagicMock() - mock_current_app.config.get = Mock(side_effect=get_config) - # Mock existing theme mock_existing_theme = Mock(spec=Theme) mock_existing_theme.json_data = '{"old": "data"}' @@ -169,20 +163,17 @@ class TestSeedSystemThemesCommand: # Note: commit is handled by @transaction() decorator, not directly called mock_session.add.assert_not_called() # Should not add new theme - @patch("superset.commands.theme.seed.current_app") + @with_config( + { + "THEME_DEFAULT": {"algorithm": "default", "token": {}}, + "THEME_DARK": None, + } + ) @patch("superset.commands.theme.seed.db") @patch("superset.commands.theme.seed.logger") - def test_run_handles_database_error(self, mock_logger, mock_db, mock_current_app): + def test_run_handles_database_error(self, mock_logger, mock_db, app): """Test run handles database errors gracefully""" # Arrange - default_theme = {"algorithm": "default", "token": {}} - - def get_config(key): - return {"THEME_DEFAULT": default_theme, "THEME_DARK": None}.get(key) - - mock_current_app.config = MagicMock() - mock_current_app.config.get = Mock(side_effect=get_config) - mock_session = Mock() mock_db.session = mock_session mock_session.query.side_effect = Exception("Database error") @@ -193,20 +184,16 @@ class TestSeedSystemThemesCommand: with 
pytest.raises(Exception, match="Database error"): command.run() # Should raise exception due to @transaction() decorator - @patch("superset.commands.theme.seed.current_app") + @with_config( + { + "THEME_DEFAULT": {"algorithm": "default", "token": {}}, + "THEME_DARK": {"algorithm": "dark", "token": {}}, + } + ) @patch("superset.commands.theme.seed.db") - def test_run_with_both_themes(self, mock_db, mock_current_app): + def test_run_with_both_themes(self, mock_db, app): """Test run when both THEME_DEFAULT and THEME_DARK are configured""" # Arrange - default_theme = {"algorithm": "default", "token": {}} - dark_theme = {"algorithm": "dark", "token": {}} - - def get_config(key): - return {"THEME_DEFAULT": default_theme, "THEME_DARK": dark_theme}.get(key) - - mock_current_app.config = MagicMock() - mock_current_app.config.get = Mock(side_effect=get_config) - mock_session = Mock() mock_db.session = mock_session mock_session.query.return_value.filter.return_value.first.return_value = None diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py index aa5d8d08aea..301e033e79c 100644 --- a/tests/unit_tests/thumbnails/test_digest.py +++ b/tests/unit_tests/thumbnails/test_digest.py @@ -21,6 +21,7 @@ from typing import Any, TYPE_CHECKING from unittest.mock import MagicMock, patch, PropertyMock import pytest +from flask import current_app from flask_appbuilder.security.sqla.models import User from superset.connectors.sqla.models import BaseDatasource, SqlaTable @@ -233,8 +234,9 @@ def test_dashboard_digest( use_custom_digest: bool, rls_datasources: list[dict[str, Any]], expected_result: str | Exception, + app_context: None, ) -> None: - from superset import app, security_manager + from superset import security_manager from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.thumbnails.digest import get_dashboard_digest @@ -261,7 +263,7 @@ def test_dashboard_digest( with ( patch.dict( - 
app.config, + current_app.config, { "THUMBNAIL_EXECUTORS": execute_as, "THUMBNAIL_DASHBOARD_DIGEST_FUNC": func, @@ -372,8 +374,9 @@ def test_chart_digest( use_custom_digest: bool, rls_datasource: dict[str, Any] | None, expected_result: str | Exception, + app_context: None, ) -> None: - from superset import app, security_manager + from superset import security_manager from superset.models.slice import Slice from superset.thumbnails.digest import get_chart_digest @@ -397,7 +400,7 @@ def test_chart_digest( with ( patch.dict( - app.config, + current_app.config, { "THUMBNAIL_EXECUTORS": execute_as, "THUMBNAIL_CHART_DIGEST_FUNC": func, diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index e629f002906..cef52d4f5c1 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -47,6 +47,7 @@ from superset.utils.core import ( QuerySource, remove_extra_adhoc_filters, ) +from tests.conftest import with_config ADHOC_FILTER: QueryObjectFilterClause = { "col": "foo", @@ -622,19 +623,25 @@ def test_get_query_source_from_request( assert get_query_source_from_request() == expected -def test_get_user_agent(mocker: MockerFixture) -> None: +@with_config({"USER_AGENT_FUNC": None}) +def test_get_user_agent(mocker: MockerFixture, app_context: None) -> None: database_mock = mocker.MagicMock() database_mock.database_name = "mydb" - current_app_mock = mocker.patch("superset.utils.core.current_app") - current_app_mock.config = {"USER_AGENT_FUNC": None} - assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "Apache Superset", ( "The default user agent should be returned" ) - current_app_mock.config["USER_AGENT_FUNC"] = ( - lambda database, source: f"{database.database_name} {source.name}" - ) + + +@with_config( + { + "USER_AGENT_FUNC": lambda database, + source: f"{database.database_name} {source.name}" + } +) +def test_get_user_agent_custom(mocker: MockerFixture, app_context: None) -> None: + database_mock = 
mocker.MagicMock() + database_mock.database_name = "mydb" assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "mydb DASHBOARD", ( "the custom user agent function result should have been returned" diff --git a/tests/unit_tests/utils/test_decorators.py b/tests/unit_tests/utils/test_decorators.py index df767e93870..8073fbb25e8 100644 --- a/tests/unit_tests/utils/test_decorators.py +++ b/tests/unit_tests/utils/test_decorators.py @@ -26,7 +26,6 @@ from unittest.mock import call, Mock, patch import pytest from pytest_mock import MockerFixture -from superset import app from superset.utils import decorators from superset.utils.backports import StrEnum @@ -78,7 +77,7 @@ def test_statsd_gauge( raise FileNotFoundError("Not found") return "OK" - with patch.object(app.config["STATS_LOGGER"], "gauge") as mock: + with patch("superset.extensions.stats_logger_manager.instance.gauge") as mock: cm = ( pytest.raises(expected_exception) if isclass(expected_exception) and issubclass(expected_exception, Exception)