mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
chore: proper current_app.config proxy usage (#34345)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
6c9cda758a
commit
cb27d5fe8d
@@ -38,10 +38,11 @@ from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlalchemy.dialects
|
||||
from flask import current_app as app
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
from sqlalchemy.exc import NoSuchModuleError
|
||||
|
||||
from superset import app, feature_flag_manager
|
||||
from superset import feature_flag_manager
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,7 +42,7 @@ import requests
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from deprecation import deprecated
|
||||
from flask import current_app, g, url_for
|
||||
from flask import current_app as app, g, url_for
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
@@ -459,7 +459,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
def is_oauth2_enabled(cls) -> bool:
|
||||
return (
|
||||
cls.supports_oauth2
|
||||
and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"]
|
||||
and cls.engine_name in app.config["DATABASE_OAUTH2_CLIENTS"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -512,12 +512,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"""
|
||||
Build the DB engine spec level OAuth2 client config.
|
||||
"""
|
||||
oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"]
|
||||
oauth2_config = app.config["DATABASE_OAUTH2_CLIENTS"]
|
||||
if cls.engine_name not in oauth2_config:
|
||||
return None
|
||||
|
||||
db_engine_spec_config = oauth2_config[cls.engine_name]
|
||||
redirect_uri = current_app.config.get(
|
||||
redirect_uri = app.config.get(
|
||||
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||
url_for("DatabaseRestApi.oauth2", _external=True),
|
||||
)
|
||||
@@ -573,7 +573,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"""
|
||||
Exchange authorization code for refresh/access tokens.
|
||||
"""
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
req_body = {
|
||||
"code": code,
|
||||
@@ -595,7 +595,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"""
|
||||
Refresh an access token that has expired.
|
||||
"""
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
req_body = {
|
||||
"client_id": config["id"],
|
||||
@@ -871,7 +871,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
|
||||
ret_list = []
|
||||
time_grains = builtin_time_grains.copy()
|
||||
time_grains.update(current_app.config["TIME_GRAIN_ADDONS"])
|
||||
time_grains.update(app.config["TIME_GRAIN_ADDONS"])
|
||||
for duration, func in cls.get_time_grain_expressions().items():
|
||||
if duration in time_grains:
|
||||
name = time_grains[duration]
|
||||
@@ -950,9 +950,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"""
|
||||
# TODO: use @memoize decorator or similar to avoid recomputation on every call
|
||||
time_grain_expressions = cls._time_grain_expressions.copy()
|
||||
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
|
||||
grain_addon_expressions = app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
|
||||
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
|
||||
denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"]
|
||||
denylist: list[str] = app.config["TIME_GRAIN_DENYLIST"]
|
||||
for key in denylist:
|
||||
time_grain_expressions.pop(key, None)
|
||||
|
||||
@@ -2235,7 +2235,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
|
||||
:param sqlalchemy_uri:
|
||||
"""
|
||||
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
|
||||
if db_engine_uri_validator := app.config["DB_SQLA_URI_VALIDATOR"]:
|
||||
db_engine_uri_validator(sqlalchemy_uri)
|
||||
|
||||
if existing_disallowed := cls.disallow_uri_query_params.get(
|
||||
|
||||
@@ -22,7 +22,7 @@ from datetime import datetime
|
||||
from typing import Any, cast, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
from flask import current_app
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
@@ -247,7 +247,7 @@ try:
|
||||
)
|
||||
set_setting(
|
||||
"product_name",
|
||||
f"superset/{current_app.config.get('VERSION_STRING', 'dev')}",
|
||||
f"superset/{app.config.get('VERSION_STRING', 'dev')}",
|
||||
)
|
||||
except ImportError: # ClickHouse Connect not installed, do nothing
|
||||
pass
|
||||
|
||||
@@ -24,13 +24,13 @@ from typing import Any, TYPE_CHECKING, TypedDict
|
||||
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset.config import VERSION_STRING
|
||||
from superset.constants import TimeGrain
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
@@ -252,7 +252,8 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
|
||||
delim = " " if custom_user_agent else ""
|
||||
user_agent = get_user_agent(database, source)
|
||||
user_agent = user_agent.replace(" ", "-").lower()
|
||||
user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}"
|
||||
version_string = app.config["VERSION_STRING"]
|
||||
user_agent = f"{user_agent}/{version_string}{delim}{custom_user_agent}"
|
||||
config.setdefault("custom_user_agent", user_agent)
|
||||
|
||||
return extra
|
||||
|
||||
@@ -29,7 +29,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from flask import current_app, g
|
||||
from flask import current_app as app, g
|
||||
from sqlalchemy import Column, text, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
@@ -66,7 +66,7 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
|
||||
import boto3 # pylint: disable=all
|
||||
from boto3.s3.transfer import TransferConfig # pylint: disable=all
|
||||
|
||||
bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
|
||||
bucket_path = app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
|
||||
|
||||
if not bucket_path:
|
||||
logger.info("No upload bucket specified")
|
||||
@@ -224,7 +224,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
|
||||
dir=app.config["UPLOAD_FOLDER"], suffix=".parquet"
|
||||
) as file:
|
||||
pq.write_table(pa.Table.from_pandas(df), where=file.name)
|
||||
|
||||
@@ -243,9 +243,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
),
|
||||
location=upload_to_s3(
|
||||
filename=file.name,
|
||||
upload_prefix=current_app.config[
|
||||
"CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
|
||||
](database, g.user, table.schema),
|
||||
upload_prefix=app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
|
||||
database, g.user, table.schema
|
||||
),
|
||||
table=table,
|
||||
),
|
||||
)
|
||||
@@ -401,12 +401,13 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
last_log_line = len(log_lines)
|
||||
if needs_commit:
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
|
||||
if sleep_interval := app.config.get("HIVE_POLL_INTERVAL"):
|
||||
logger.warning(
|
||||
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead" # noqa: E501
|
||||
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. "
|
||||
"Please use DB_POLL_INTERVAL_SECONDS instead"
|
||||
)
|
||||
else:
|
||||
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
|
||||
sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get(
|
||||
cls.engine, 5
|
||||
)
|
||||
time.sleep(sleep_interval)
|
||||
|
||||
@@ -24,7 +24,7 @@ from datetime import datetime
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from flask import current_app as app
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
@@ -155,7 +155,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
||||
|
||||
if needs_commit:
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
|
||||
sleep_interval = app.config["DB_POLL_INTERVAL_SECONDS"].get(
|
||||
cls.engine, 5
|
||||
)
|
||||
time.sleep(sleep_interval)
|
||||
|
||||
@@ -28,10 +28,9 @@ with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be inst
|
||||
# Ensure pyocient inherits Superset's logging level
|
||||
import geojson
|
||||
import pyocient
|
||||
from flask import current_app as app
|
||||
from shapely import wkt
|
||||
|
||||
from superset import app
|
||||
|
||||
superset_log_level = app.config["LOG_LEVEL"]
|
||||
pyocient.logger.setLevel(superset_log_level)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from typing import Any, cast, Optional, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
from flask import current_app
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from packaging.version import Version
|
||||
from sqlalchemy import Column, literal_column, types
|
||||
@@ -1318,7 +1318,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
|
||||
query_id = query.id
|
||||
poll_interval = query.database.connect_args.get(
|
||||
"poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]
|
||||
"poll_interval", app.config["PRESTO_POLL_INTERVAL"]
|
||||
)
|
||||
logger.info("Query %i: Polling the cursor for progress", query_id)
|
||||
polled = cursor.poll()
|
||||
|
||||
@@ -21,7 +21,7 @@ from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from urllib import parse
|
||||
|
||||
from flask import current_app
|
||||
from flask import current_app as app
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
@@ -498,8 +498,8 @@ class SingleStoreSpec(BasicParametersMixin, BaseEngineSpec):
|
||||
"conn_attrs",
|
||||
{
|
||||
"_connector_name": "SingleStore Superset Database Engine",
|
||||
"_connector_version": current_app.config.get("VERSION_STRING", "dev"),
|
||||
"_product_version": current_app.config.get("VERSION_STRING", "dev"),
|
||||
"_connector_version": app.config.get("VERSION_STRING", "dev"),
|
||||
"_product_version": app.config.get("VERSION_STRING", "dev"),
|
||||
},
|
||||
)
|
||||
return uri, connect_args
|
||||
|
||||
@@ -27,7 +27,7 @@ from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from flask import current_app
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import types
|
||||
@@ -411,9 +411,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
)
|
||||
connect_args["private_key"] = pkb
|
||||
else:
|
||||
allowed_extra_auths = current_app.config[
|
||||
"ALLOWED_EXTRA_AUTHENTICATIONS"
|
||||
].get("snowflake", {})
|
||||
allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get(
|
||||
"snowflake", {}
|
||||
)
|
||||
if auth_method in allowed_extra_auths:
|
||||
snowflake_auth = allowed_extra_auths.get(auth_method)
|
||||
else:
|
||||
|
||||
@@ -23,7 +23,7 @@ import time
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from flask import copy_current_request_context, ctx, current_app, Flask, g
|
||||
from flask import copy_current_request_context, ctx, current_app as app, Flask, g
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
@@ -249,7 +249,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||
args=(
|
||||
execute_result,
|
||||
execute_event,
|
||||
current_app._get_current_object(), # pylint: disable=protected-access
|
||||
app._get_current_object(), # pylint: disable=protected-access
|
||||
g._get_current_object(), # pylint: disable=protected-access
|
||||
),
|
||||
)
|
||||
@@ -352,9 +352,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||
elif auth_method == "jwt":
|
||||
from trino.auth import JWTAuthentication as trino_auth # noqa
|
||||
else:
|
||||
allowed_extra_auths = current_app.config[
|
||||
"ALLOWED_EXTRA_AUTHENTICATIONS"
|
||||
].get("trino", {})
|
||||
allowed_extra_auths = app.config["ALLOWED_EXTRA_AUTHENTICATIONS"].get(
|
||||
"trino", {}
|
||||
)
|
||||
if auth_method in allowed_extra_auths:
|
||||
trino_auth = allowed_extra_auths.get(auth_method)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user