mirror of
https://github.com/apache/superset.git
synced 2026-04-26 19:44:58 +00:00
refactor: Deprecate ensure_user_is_set in favor of override_user (#20502)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from superset.extensions import (
|
|||||||
security_manager,
|
security_manager,
|
||||||
)
|
)
|
||||||
from superset.utils.cache import generate_cache_key, set_and_log_cache
|
from superset.utils.cache import generate_cache_key, set_and_log_cache
|
||||||
|
from superset.utils.core import override_user
|
||||||
from superset.views.utils import get_datasource_info, get_viz
|
from superset.views.utils import get_datasource_info, get_viz
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -44,16 +45,6 @@ query_timeout = current_app.config[
|
|||||||
] # TODO: new config key
|
] # TODO: new config key
|
||||||
|
|
||||||
|
|
||||||
def ensure_user_is_set(user_id: Optional[int]) -> None:
|
|
||||||
user_is_not_set = not (hasattr(g, "user") and g.user is not None)
|
|
||||||
if user_is_not_set and user_id is not None:
|
|
||||||
# pylint: disable=assigning-non-slot
|
|
||||||
g.user = security_manager.get_user_by_id(user_id)
|
|
||||||
elif user_is_not_set:
|
|
||||||
# pylint: disable=assigning-non-slot
|
|
||||||
g.user = security_manager.get_anonymous_user()
|
|
||||||
|
|
||||||
|
|
||||||
def set_form_data(form_data: Dict[str, Any]) -> None:
|
def set_form_data(form_data: Dict[str, Any]) -> None:
|
||||||
# pylint: disable=assigning-non-slot
|
# pylint: disable=assigning-non-slot
|
||||||
g.form_data = form_data
|
g.form_data = form_data
|
||||||
@@ -76,30 +67,35 @@ def load_chart_data_into_cache(
|
|||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from superset.charts.data.commands.get_data_command import ChartDataCommand
|
from superset.charts.data.commands.get_data_command import ChartDataCommand
|
||||||
|
|
||||||
try:
|
user = (
|
||||||
ensure_user_is_set(job_metadata.get("user_id"))
|
security_manager.get_user_by_id(job_metadata.get("user_id"))
|
||||||
set_form_data(form_data)
|
or security_manager.get_anonymous_user()
|
||||||
query_context = _create_query_context_from_form(form_data)
|
)
|
||||||
command = ChartDataCommand(query_context)
|
|
||||||
result = command.run(cache=True)
|
with override_user(user, force=False):
|
||||||
cache_key = result["cache_key"]
|
try:
|
||||||
result_url = f"/api/v1/chart/data/{cache_key}"
|
set_form_data(form_data)
|
||||||
async_query_manager.update_job(
|
query_context = _create_query_context_from_form(form_data)
|
||||||
job_metadata,
|
command = ChartDataCommand(query_context)
|
||||||
async_query_manager.STATUS_DONE,
|
result = command.run(cache=True)
|
||||||
result_url=result_url,
|
cache_key = result["cache_key"]
|
||||||
)
|
result_url = f"/api/v1/chart/data/{cache_key}"
|
||||||
except SoftTimeLimitExceeded as ex:
|
async_query_manager.update_job(
|
||||||
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
|
job_metadata,
|
||||||
raise ex
|
async_query_manager.STATUS_DONE,
|
||||||
except Exception as ex:
|
result_url=result_url,
|
||||||
# TODO: QueryContext should support SIP-40 style errors
|
)
|
||||||
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
|
except SoftTimeLimitExceeded as ex:
|
||||||
errors = [{"message": error}]
|
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
|
||||||
async_query_manager.update_job(
|
raise ex
|
||||||
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
|
except Exception as ex:
|
||||||
)
|
# TODO: QueryContext should support SIP-40 style errors
|
||||||
raise ex
|
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
|
||||||
|
errors = [{"message": error}]
|
||||||
|
async_query_manager.update_job(
|
||||||
|
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
|
||||||
|
)
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
|
@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
|
||||||
@@ -110,53 +106,61 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
|
|||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
cache_key_prefix = "ejr-" # ejr: explore_json request
|
cache_key_prefix = "ejr-" # ejr: explore_json request
|
||||||
try:
|
|
||||||
ensure_user_is_set(job_metadata.get("user_id"))
|
|
||||||
set_form_data(form_data)
|
|
||||||
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
|
|
||||||
|
|
||||||
# Perform a deep copy here so that below we can cache the original
|
user = (
|
||||||
# value of the form_data object. This is necessary since the viz
|
security_manager.get_user_by_id(job_metadata.get("user_id"))
|
||||||
# objects modify the form_data object. If the modified version were
|
or security_manager.get_anonymous_user()
|
||||||
# to be cached here, it will lead to a cache miss when clients
|
)
|
||||||
# attempt to retrieve the value of the completed async query.
|
|
||||||
original_form_data = copy.deepcopy(form_data)
|
|
||||||
|
|
||||||
viz_obj = get_viz(
|
with override_user(user, force=False):
|
||||||
datasource_type=cast(str, datasource_type),
|
try:
|
||||||
datasource_id=datasource_id,
|
set_form_data(form_data)
|
||||||
form_data=form_data,
|
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
|
||||||
force=force,
|
|
||||||
)
|
|
||||||
# run query & cache results
|
|
||||||
payload = viz_obj.get_payload()
|
|
||||||
if viz_obj.has_error(payload):
|
|
||||||
raise SupersetVizException(errors=payload["errors"])
|
|
||||||
|
|
||||||
# Cache the original form_data value for async retrieval
|
# Perform a deep copy here so that below we can cache the original
|
||||||
cache_value = {
|
# value of the form_data object. This is necessary since the viz
|
||||||
"form_data": original_form_data,
|
# objects modify the form_data object. If the modified version were
|
||||||
"response_type": response_type,
|
# to be cached here, it will lead to a cache miss when clients
|
||||||
}
|
# attempt to retrieve the value of the completed async query.
|
||||||
cache_key = generate_cache_key(cache_value, cache_key_prefix)
|
original_form_data = copy.deepcopy(form_data)
|
||||||
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
|
|
||||||
result_url = f"/superset/explore_json/data/{cache_key}"
|
|
||||||
async_query_manager.update_job(
|
|
||||||
job_metadata,
|
|
||||||
async_query_manager.STATUS_DONE,
|
|
||||||
result_url=result_url,
|
|
||||||
)
|
|
||||||
except SoftTimeLimitExceeded as ex:
|
|
||||||
logger.warning("A timeout occurred while loading explore json, error: %s", ex)
|
|
||||||
raise ex
|
|
||||||
except Exception as ex:
|
|
||||||
if isinstance(ex, SupersetVizException):
|
|
||||||
errors = ex.errors # pylint: disable=no-member
|
|
||||||
else:
|
|
||||||
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
|
|
||||||
errors = [error]
|
|
||||||
|
|
||||||
async_query_manager.update_job(
|
viz_obj = get_viz(
|
||||||
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
|
datasource_type=cast(str, datasource_type),
|
||||||
)
|
datasource_id=datasource_id,
|
||||||
raise ex
|
form_data=form_data,
|
||||||
|
force=force,
|
||||||
|
)
|
||||||
|
# run query & cache results
|
||||||
|
payload = viz_obj.get_payload()
|
||||||
|
if viz_obj.has_error(payload):
|
||||||
|
raise SupersetVizException(errors=payload["errors"])
|
||||||
|
|
||||||
|
# Cache the original form_data value for async retrieval
|
||||||
|
cache_value = {
|
||||||
|
"form_data": original_form_data,
|
||||||
|
"response_type": response_type,
|
||||||
|
}
|
||||||
|
cache_key = generate_cache_key(cache_value, cache_key_prefix)
|
||||||
|
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
|
||||||
|
result_url = f"/superset/explore_json/data/{cache_key}"
|
||||||
|
async_query_manager.update_job(
|
||||||
|
job_metadata,
|
||||||
|
async_query_manager.STATUS_DONE,
|
||||||
|
result_url=result_url,
|
||||||
|
)
|
||||||
|
except SoftTimeLimitExceeded as ex:
|
||||||
|
logger.warning(
|
||||||
|
"A timeout occurred while loading explore json, error: %s", ex
|
||||||
|
)
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
if isinstance(ex, SupersetVizException):
|
||||||
|
errors = ex.errors # pylint: disable=no-member
|
||||||
|
else:
|
||||||
|
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
|
||||||
|
errors = [error]
|
||||||
|
|
||||||
|
async_query_manager.update_job(
|
||||||
|
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
|
||||||
|
)
|
||||||
|
raise ex
|
||||||
|
|||||||
@@ -1453,23 +1453,27 @@ def get_user_id() -> Optional[int]:
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def override_user(user: Optional[User]) -> Iterator[Any]:
|
def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]:
|
||||||
"""
|
"""
|
||||||
Temporarily override the current user (if defined) per `flask.g`.
|
Temporarily override the current user per `flask.g` with the specified user.
|
||||||
|
|
||||||
Sometimes, often in the context of async Celery tasks, it is useful to switch the
|
Sometimes, often in the context of async Celery tasks, it is useful to switch the
|
||||||
current user (which may be undefined) to different one, execute some SQLAlchemy
|
current user (which may be undefined) to different one, execute some SQLAlchemy
|
||||||
tasks and then revert back to the original one.
|
tasks et al. and then revert back to the original one.
|
||||||
|
|
||||||
:param user: The override user
|
:param user: The override user
|
||||||
|
:param force: Whether to override the current user if set
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=assigning-non-slot
|
# pylint: disable=assigning-non-slot
|
||||||
if hasattr(g, "user"):
|
if hasattr(g, "user"):
|
||||||
current = g.user
|
if force or g.user is None:
|
||||||
g.user = user
|
current = g.user
|
||||||
yield
|
g.user = user
|
||||||
g.user = current
|
yield
|
||||||
|
g.user = current
|
||||||
|
else:
|
||||||
|
yield
|
||||||
else:
|
else:
|
||||||
g.user = user
|
g.user = user
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -562,34 +562,34 @@ def test_get_username(
|
|||||||
assert get_username() == username
|
assert get_username() == username
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("username", [None, "alpha", "gamma"])
|
||||||
"username",
|
@pytest.mark.parametrize("force", [False, True])
|
||||||
[
|
|
||||||
None,
|
|
||||||
"alpha",
|
|
||||||
"gamma",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_override_user(
|
def test_override_user(
|
||||||
app_context: AppContext,
|
app_context: AppContext,
|
||||||
mocker: MockFixture,
|
mocker: MockFixture,
|
||||||
username: str,
|
username: str,
|
||||||
|
force: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
mock_g = mocker.patch("superset.utils.core.g", spec={})
|
mock_g = mocker.patch("superset.utils.core.g", spec={})
|
||||||
admin = security_manager.find_user(username="admin")
|
admin = security_manager.find_user(username="admin")
|
||||||
user = security_manager.find_user(username)
|
user = security_manager.find_user(username)
|
||||||
|
|
||||||
assert not hasattr(mock_g, "user")
|
with override_user(user, force):
|
||||||
|
|
||||||
with override_user(user):
|
|
||||||
assert mock_g.user == user
|
assert mock_g.user == user
|
||||||
|
|
||||||
assert not hasattr(mock_g, "user")
|
assert not hasattr(mock_g, "user")
|
||||||
|
|
||||||
|
mock_g.user = None
|
||||||
|
|
||||||
|
with override_user(user, force):
|
||||||
|
assert mock_g.user == user
|
||||||
|
|
||||||
|
assert mock_g.user is None
|
||||||
|
|
||||||
mock_g.user = admin
|
mock_g.user = admin
|
||||||
|
|
||||||
with override_user(user):
|
with override_user(user, force):
|
||||||
assert mock_g.user == user
|
assert mock_g.user == user if force else admin
|
||||||
|
|
||||||
assert mock_g.user == admin
|
assert mock_g.user == admin
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from superset.exceptions import SupersetException
|
|||||||
from superset.extensions import async_query_manager, security_manager
|
from superset.extensions import async_query_manager, security_manager
|
||||||
from superset.tasks import async_queries
|
from superset.tasks import async_queries
|
||||||
from superset.tasks.async_queries import (
|
from superset.tasks.async_queries import (
|
||||||
ensure_user_is_set,
|
|
||||||
load_chart_data_into_cache,
|
load_chart_data_into_cache,
|
||||||
load_explore_json_into_cache,
|
load_explore_json_into_cache,
|
||||||
)
|
)
|
||||||
@@ -58,12 +57,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
with mock.patch.object(
|
load_chart_data_into_cache(job_metadata, query_context)
|
||||||
async_queries, "ensure_user_is_set"
|
|
||||||
) as ensure_user_is_set:
|
|
||||||
load_chart_data_into_cache(job_metadata, query_context)
|
|
||||||
|
|
||||||
ensure_user_is_set.assert_called_once_with(user.id)
|
|
||||||
mock_set_form_data.assert_called_once_with(query_context)
|
mock_set_form_data.assert_called_once_with(query_context)
|
||||||
mock_update_job.assert_called_once_with(
|
mock_update_job.assert_called_once_with(
|
||||||
job_metadata, "done", result_url=mock.ANY
|
job_metadata, "done", result_url=mock.ANY
|
||||||
@@ -85,11 +79,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
with pytest.raises(ChartDataQueryFailedError):
|
with pytest.raises(ChartDataQueryFailedError):
|
||||||
with mock.patch.object(
|
load_chart_data_into_cache(job_metadata, query_context)
|
||||||
async_queries, "ensure_user_is_set"
|
|
||||||
) as ensure_user_is_set:
|
|
||||||
load_chart_data_into_cache(job_metadata, query_context)
|
|
||||||
ensure_user_is_set.assert_called_once_with(user.id)
|
|
||||||
|
|
||||||
mock_run_command.assert_called_once_with(cache=True)
|
mock_run_command.assert_called_once_with(cache=True)
|
||||||
errors = [{"message": "Error: foo"}]
|
errors = [{"message": "Error: foo"}]
|
||||||
@@ -115,11 +105,11 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
with pytest.raises(SoftTimeLimitExceeded):
|
with pytest.raises(SoftTimeLimitExceeded):
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
async_queries,
|
async_queries,
|
||||||
"ensure_user_is_set",
|
"set_form_data",
|
||||||
) as ensure_user_is_set:
|
) as set_form_data:
|
||||||
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
|
set_form_data.side_effect = SoftTimeLimitExceeded()
|
||||||
load_chart_data_into_cache(job_metadata, form_data)
|
load_chart_data_into_cache(job_metadata, form_data)
|
||||||
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
|
set_form_data.assert_called_once_with(form_data, "error", errors=errors)
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||||
@mock.patch.object(async_query_manager, "update_job")
|
@mock.patch.object(async_query_manager, "update_job")
|
||||||
@@ -145,12 +135,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
with mock.patch.object(
|
load_explore_json_into_cache(job_metadata, form_data)
|
||||||
async_queries, "ensure_user_is_set"
|
|
||||||
) as ensure_user_is_set:
|
|
||||||
load_explore_json_into_cache(job_metadata, form_data)
|
|
||||||
|
|
||||||
ensure_user_is_set.assert_called_once_with(user.id)
|
|
||||||
mock_update_job.assert_called_once_with(
|
mock_update_job.assert_called_once_with(
|
||||||
job_metadata, "done", result_url=mock.ANY
|
job_metadata, "done", result_url=mock.ANY
|
||||||
)
|
)
|
||||||
@@ -172,11 +157,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(SupersetException):
|
with pytest.raises(SupersetException):
|
||||||
with mock.patch.object(
|
load_explore_json_into_cache(job_metadata, form_data)
|
||||||
async_queries, "ensure_user_is_set"
|
|
||||||
) as ensure_user_is_set:
|
|
||||||
load_explore_json_into_cache(job_metadata, form_data)
|
|
||||||
ensure_user_is_set.assert_called_once_with(user.id)
|
|
||||||
|
|
||||||
mock_set_form_data.assert_called_once_with(form_data)
|
mock_set_form_data.assert_called_once_with(form_data)
|
||||||
errors = ["The dataset associated with this chart no longer exists"]
|
errors = ["The dataset associated with this chart no longer exists"]
|
||||||
@@ -202,49 +183,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
with pytest.raises(SoftTimeLimitExceeded):
|
with pytest.raises(SoftTimeLimitExceeded):
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
async_queries,
|
async_queries,
|
||||||
"ensure_user_is_set",
|
"set_form_data",
|
||||||
) as ensure_user_is_set:
|
) as set_form_data:
|
||||||
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
|
set_form_data.side_effect = SoftTimeLimitExceeded()
|
||||||
load_explore_json_into_cache(job_metadata, form_data)
|
load_explore_json_into_cache(job_metadata, form_data)
|
||||||
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
|
set_form_data.assert_called_once_with(form_data, "error", errors=errors)
|
||||||
|
|
||||||
def test_ensure_user_is_set(self):
|
|
||||||
g_user_is_set = hasattr(g, "user")
|
|
||||||
original_g_user = g.user if g_user_is_set else None
|
|
||||||
|
|
||||||
if g_user_is_set:
|
|
||||||
del g.user
|
|
||||||
|
|
||||||
self.assertFalse(hasattr(g, "user"))
|
|
||||||
ensure_user_is_set(1)
|
|
||||||
self.assertTrue(hasattr(g, "user"))
|
|
||||||
self.assertFalse(g.user.is_anonymous)
|
|
||||||
self.assertEqual(1, get_user_id())
|
|
||||||
|
|
||||||
del g.user
|
|
||||||
|
|
||||||
self.assertFalse(hasattr(g, "user"))
|
|
||||||
ensure_user_is_set(None)
|
|
||||||
self.assertTrue(hasattr(g, "user"))
|
|
||||||
self.assertTrue(g.user.is_anonymous)
|
|
||||||
self.assertEqual(None, get_user_id())
|
|
||||||
|
|
||||||
del g.user
|
|
||||||
|
|
||||||
g.user = security_manager.get_user_by_id(2)
|
|
||||||
self.assertEqual(2, get_user_id())
|
|
||||||
|
|
||||||
ensure_user_is_set(1)
|
|
||||||
self.assertTrue(hasattr(g, "user"))
|
|
||||||
self.assertFalse(g.user.is_anonymous)
|
|
||||||
self.assertEqual(2, get_user_id())
|
|
||||||
|
|
||||||
ensure_user_is_set(None)
|
|
||||||
self.assertTrue(hasattr(g, "user"))
|
|
||||||
self.assertFalse(g.user.is_anonymous)
|
|
||||||
self.assertEqual(2, get_user_id())
|
|
||||||
|
|
||||||
if g_user_is_set:
|
|
||||||
g.user = original_g_user
|
|
||||||
else:
|
|
||||||
del g.user
|
|
||||||
|
|||||||
Reference in New Issue
Block a user