diff --git a/superset/models/core.py b/superset/models/core.py index dc8df7ca13d..28f454295f1 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -36,7 +36,7 @@ import numpy import pandas as pd import sqlalchemy as sqla import sshtunnel -from flask import g, request +from flask import g from flask_appbuilder import Model from marshmallow.exceptions import ValidationError from sqlalchemy import ( @@ -84,7 +84,7 @@ from superset.superset_typing import ( ) from superset.utils import cache as cache_util, core as utils, json from superset.utils.backports import StrEnum -from superset.utils.core import get_username +from superset.utils.core import get_query_source_from_request, get_username from superset.utils.oauth2 import ( check_for_oauth2, get_oauth2_access_token, @@ -537,13 +537,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable self.update_params_from_encrypted_extra(params) if DB_CONNECTION_MUTATOR: - if not source and request and request.referrer: - if "/superset/dashboard/" in request.referrer: - source = utils.QuerySource.DASHBOARD - elif "/explore/" in request.referrer: - source = utils.QuerySource.CHART - elif "/sqllab/" in request.referrer: - source = utils.QuerySource.SQL_LAB + source = source or get_query_source_from_request() sqlalchemy_url, params = DB_CONNECTION_MUTATOR( sqlalchemy_url, diff --git a/superset/utils/core.py b/superset/utils/core.py index 9e1bee5ecb5..69de1707ede 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1805,7 +1805,20 @@ def to_int(v: Any, value_if_invalid: int = 0) -> int: return value_if_invalid +def get_query_source_from_request() -> QuerySource | None: + if not request or not request.referrer: + return None + if "/superset/dashboard/" in request.referrer: + return QuerySource.DASHBOARD + if "/explore/" in request.referrer: + return QuerySource.CHART + if "/sqllab/" in request.referrer: + return QuerySource.SQL_LAB + return None + + def get_user_agent(database: Database, source: QuerySource | None) -> str: + source = source or get_query_source_from_request() if user_agent_func := current_app.config["USER_AGENT_FUNC"]: return user_agent_func(database, source) diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index 74b9373392a..a670c98307d 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -31,6 +31,7 @@ from superset.utils.core import ( generic_find_constraint_name, generic_find_fk_constraint_name, get_datasource_full_name, + get_query_source_from_request, get_user_agent, is_test, normalize_dttm_col, @@ -401,6 +402,28 @@ def test_get_datasource_full_name(): ) +@pytest.mark.parametrize( + "referrer,expected", + [ + (None, None), + ("https://mysuperset.com/abc", None), + ("https://mysuperset.com/superset/dashboard/", QuerySource.DASHBOARD), + ("https://mysuperset.com/explore/", QuerySource.CHART), + ("https://mysuperset.com/sqllab/", QuerySource.SQL_LAB), + ], +) +def test_get_query_source_from_request( + referrer: str | None, + expected: QuerySource | None, + mocker: MockerFixture, +) -> None: + if referrer: + request_mock = mocker.patch("superset.utils.core.request") + request_mock.referrer = referrer + + assert get_query_source_from_request() == expected + + def test_get_user_agent(mocker: MockerFixture) -> None: database_mock = mocker.MagicMock() database_mock.database_name = "mydb"