fix: always extract query source from request (#32525)

This commit is contained in:
Ville Brofeldt
2025-03-06 14:17:21 -08:00
committed by GitHub
parent 99238dccbb
commit 68e8d9858c
3 changed files with 39 additions and 9 deletions

View File

@@ -36,7 +36,7 @@ import numpy
import pandas as pd
import sqlalchemy as sqla
import sshtunnel
from flask import g, request
from flask import g
from flask_appbuilder import Model
from marshmallow.exceptions import ValidationError
from sqlalchemy import (
@@ -84,7 +84,7 @@ from superset.superset_typing import (
)
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
from superset.utils.core import get_query_source_from_request, get_username
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
@@ -537,13 +537,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
self.update_params_from_encrypted_extra(params)
if DB_CONNECTION_MUTATOR:
if not source and request and request.referrer:
if "/superset/dashboard/" in request.referrer:
source = utils.QuerySource.DASHBOARD
elif "/explore/" in request.referrer:
source = utils.QuerySource.CHART
elif "/sqllab/" in request.referrer:
source = utils.QuerySource.SQL_LAB
source = source or get_query_source_from_request()
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url,

View File

@@ -1805,7 +1805,20 @@ def to_int(v: Any, value_if_invalid: int = 0) -> int:
return value_if_invalid
def get_query_source_from_request() -> QuerySource | None:
if not request or not request.referrer:
return None
if "/superset/dashboard/" in request.referrer:
return QuerySource.DASHBOARD
if "/explore/" in request.referrer:
return QuerySource.CHART
if "/sqllab/" in request.referrer:
return QuerySource.SQL_LAB
return None
def get_user_agent(database: Database, source: QuerySource | None) -> str:
source = source or get_query_source_from_request()
if user_agent_func := current_app.config["USER_AGENT_FUNC"]:
return user_agent_func(database, source)

View File

@@ -31,6 +31,7 @@ from superset.utils.core import (
generic_find_constraint_name,
generic_find_fk_constraint_name,
get_datasource_full_name,
get_query_source_from_request,
get_user_agent,
is_test,
normalize_dttm_col,
@@ -401,6 +402,28 @@ def test_get_datasource_full_name():
)
@pytest.mark.parametrize(
"referrer,expected",
[
(None, None),
("https://mysuperset.com/abc", None),
("https://mysuperset.com/superset/dashboard/", QuerySource.DASHBOARD),
("https://mysuperset.com/explore/", QuerySource.CHART),
("https://mysuperset.com/sqllab/", QuerySource.SQL_LAB),
],
)
def test_get_query_source_from_request(
referrer: str | None,
expected: QuerySource | None,
mocker: MockerFixture,
) -> None:
if referrer:
request_mock = mocker.patch("superset.utils.core.request")
request_mock.referrer = referrer
assert get_query_source_from_request() == expected
def test_get_user_agent(mocker: MockerFixture) -> None:
database_mock = mocker.MagicMock()
database_mock.database_name = "mydb"