chore: proper current_app.config proxy usage (#34345)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Maxime Beauchemin
2025-07-31 19:27:42 -07:00
committed by GitHub
parent 6c9cda758a
commit cb27d5fe8d
144 changed files with 1428 additions and 1119 deletions

View File

@@ -38,10 +38,11 @@ from pathlib import Path
from typing import Any, Optional
import sqlalchemy.dialects
from flask import current_app as app
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.exc import NoSuchModuleError
from superset import app, feature_flag_manager
from superset import feature_flag_manager
from superset.db_engine_specs.base import BaseEngineSpec
logger = logging.getLogger(__name__)

View File

@@ -42,7 +42,7 @@ import requests
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask import current_app, g, url_for
from flask import current_app as app, g, url_for
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
@@ -459,7 +459,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def is_oauth2_enabled(cls) -> bool:
return (
cls.supports_oauth2
and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"]
and cls.engine_name in app.config["DATABASE_OAUTH2_CLIENTS"]
)
@classmethod
@@ -512,12 +512,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Build the DB engine spec level OAuth2 client config.
"""
oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"]
oauth2_config = app.config["DATABASE_OAUTH2_CLIENTS"]
if cls.engine_name not in oauth2_config:
return None
db_engine_spec_config = oauth2_config[cls.engine_name]
redirect_uri = current_app.config.get(
redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
@@ -573,7 +573,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Exchange authorization code for refresh/access tokens.
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
"code": code,
@@ -595,7 +595,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Refresh an access token that has expired.
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
"client_id": config["id"],
@@ -871,7 +871,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
ret_list = []
time_grains = builtin_time_grains.copy()
time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
time_grains.update(app.config["TIME_GRAIN_ADDONS"])
for duration, func in cls.get_time_grain_expressions().items():
if duration in time_grains:
name = time_grains[duration]
@@ -950,9 +950,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
# TODO: use @memoize decorator or similar to avoid recomputation on every call
time_grain_expressions = cls._time_grain_expressions.copy()
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
grain_addon_expressions = app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"]
denylist: list[str] = app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key, None)
@@ -2235,7 +2235,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sqlalchemy_uri:
"""
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
if db_engine_uri_validator := app.config["DB_SQLA_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)
if existing_disallowed := cls.disallow_uri_query_params.get(

View File

@@ -22,7 +22,7 @@ from datetime import datetime
from typing import Any, cast, TYPE_CHECKING
from urllib import parse
from flask import current_app
from flask import current_app as app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
@@ -247,7 +247,7 @@ try:
)
set_setting(
"product_name",
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
f"superset/{app.config.get('VERSION_STRING', 'dev')}",
)
except ImportError: # ClickHouse Connect not installed, do nothing
pass

View File

@@ -24,13 +24,13 @@ from typing import Any, TYPE_CHECKING, TypedDict
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import current_app as app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from superset.config import VERSION_STRING
from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
@@ -252,7 +252,8 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
delim = " " if custom_user_agent else ""
user_agent = get_user_agent(database, source)
user_agent = user_agent.replace(" ", "-").lower()
user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}"
version_string = app.config["VERSION_STRING"]
user_agent = f"{user_agent}/{version_string}{delim}{custom_user_agent}"
config.setdefault("custom_user_agent", user_agent)
return extra

View File

@@ -29,7 +29,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from flask import current_app, g
from flask import current_app as app, g
from sqlalchemy import Column, text, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
@@ -66,7 +66,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
import boto3 # pylint: disable=all
from boto3.s3.transfer import TransferConfig # pylint: disable=all
bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
bucket_path = app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
if not bucket_path:
logger.info("No upload bucket specified")
@@ -224,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
)
with tempfile.NamedTemporaryFile(
dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
dir=app.config["UPLOAD_FOLDER"], suffix=".parquet"
) as file:
pq.write_table(pa.Table.from_pandas(df), where=file.name)
@@ -243,9 +243,9 @@ class HiveEngineSpec(PrestoEngineSpec):
),
location=upload_to_s3(
filename=file.name,
upload_prefix=current_app.config[
"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
](database, g.user, table.schema),
upload_prefix=app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
database, g.user, table.schema
),
table=table,
),
)
@@ -401,12 +401,13 @@ class HiveEngineSpec(PrestoEngineSpec):
last_log_line = len(log_lines)
if needs_commit:
db.session.commit() # pylint: disable=consider-using-transaction
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
if sleep_interval := app.config.get("HIVE_POLL_INTERVAL"):
logger.warning(
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead" # noqa: E501
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. "
"Please use DB_POLL_INTERVAL_SECONDS instead"
)
else:
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get(
cls.engine, 5
)
time.sleep(sleep_interval)

View File

@@ -24,7 +24,7 @@ from datetime import datetime
from typing import Any, Optional, TYPE_CHECKING
import requests
from flask import current_app
from flask import current_app as app
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
@@ -155,7 +155,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
if needs_commit:
db.session.commit() # pylint: disable=consider-using-transaction
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get(
cls.engine, 5
)
time.sleep(sleep_interval)

View File

@@ -28,10 +28,9 @@ with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be inst
# Ensure pyocient inherits Superset's logging level
import geojson
import pyocient
from flask import current_app as app
from shapely import wkt
from superset import app
superset_log_level = app.config["LOG_LEVEL"]
pyocient.logger.setLevel(superset_log_level)

View File

@@ -30,7 +30,7 @@ from typing import Any, cast, Optional, TYPE_CHECKING
from urllib import parse
import pandas as pd
from flask import current_app
from flask import current_app as app
from flask_babel import gettext as __, lazy_gettext as _
from packaging.version import Version
from sqlalchemy import Column, literal_column, types
@@ -1318,7 +1318,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
query_id = query.id
poll_interval = query.database.connect_args.get(
"poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
"poll_interval", app.config["PRESTO_POLL_INTERVAL"]
)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()

View File

@@ -21,7 +21,7 @@ from datetime import datetime
from typing import Any, Optional
from urllib import parse
from flask import current_app
from flask import current_app as app
from sqlalchemy import types
from sqlalchemy.engine import URL
@@ -498,8 +498,8 @@ class SingleStoreSpec(BasicParametersMixin, BaseEngineSpec):
"conn_attrs",
{
"_connector_name": "SingleStore Superset Database Engine",
"_connector_version": current_app.config.get("VERSION_STRING", "dev"),
"_product_version": current_app.config.get("VERSION_STRING", "dev"),
"_connector_version": app.config.get("VERSION_STRING", "dev"),
"_product_version": app.config.get("VERSION_STRING", "dev"),
},
)
return uri, connect_args

View File

@@ -27,7 +27,7 @@ from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from flask import current_app
from flask import current_app as app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
@@ -411,9 +411,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
)
connect_args["private_key"] = pkb
else:
allowed_extra_auths = current_app.config[
"ALLOWED_EXTRA_AUTHENTICATIONS"
].get("snowflake", {})
allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get(
"snowflake", {}
)
if auth_method in allowed_extra_auths:
snowflake_auth = allowed_extra_auths.get(auth_method)
else:

View File

@@ -23,7 +23,7 @@ import time
from typing import Any, TYPE_CHECKING
import requests
from flask import copy_current_request_context, ctx, current_app, Flask, g
from flask import copy_current_request_context, ctx, current_app as app, Flask, g
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
@@ -249,7 +249,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
args=(
execute_result,
execute_event,
current_app._get_current_object(), # pylint: disable=protected-access
app._get_current_object(), # pylint: disable=protected-access
g._get_current_object(), # pylint: disable=protected-access
),
)
@@ -352,9 +352,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
elif auth_method == "jwt":
from trino.auth import JWTAuthentication as trino_auth # noqa
else:
allowed_extra_auths = current_app.config[
"ALLOWED_EXTRA_AUTHENTICATIONS"
].get("trino", {})
allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get(
"trino", {}
)
if auth_method in allowed_extra_auths:
trino_auth = allowed_extra_auths.get(auth_method)
else: