fix(#13378): Ensure g.user is set for impersonation (#13878)

This commit is contained in:
Ben Reinhart
2021-03-31 11:22:56 -07:00
committed by GitHub
parent 507041e93b
commit ca506e9396
2 changed files with 54 additions and 16 deletions

View File

@@ -18,11 +18,16 @@
import logging import logging
from typing import Any, cast, Dict, Optional from typing import Any, cast, Dict, Optional
from flask import current_app from flask import current_app, g
from superset import app from superset import app
from superset.exceptions import SupersetVizException from superset.exceptions import SupersetVizException
from superset.extensions import async_query_manager, cache_manager, celery_app from superset.extensions import (
async_query_manager,
cache_manager,
celery_app,
security_manager,
)
from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz from superset.views.utils import get_datasource_info, get_viz
@@ -32,6 +37,12 @@ query_timeout = current_app.config[
] # TODO: new config key ] # TODO: new config key
def ensure_user_is_set(user_id: Optional[int]) -> None:
user_is_set = hasattr(g, "user") and g.user is not None
if not user_is_set and user_id is not None:
g.user = security_manager.get_user_by_id(user_id)
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache( def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any], job_metadata: Dict[str, Any], form_data: Dict[str, Any],
@@ -42,6 +53,7 @@ def load_chart_data_into_cache(
with app.app_context(): # type: ignore with app.app_context(): # type: ignore
try: try:
ensure_user_is_set(job_metadata.get("user_id"))
command = ChartDataCommand() command = ChartDataCommand()
command.set_query_context(form_data) command.set_query_context(form_data)
result = command.run(cache=True) result = command.run(cache=True)
@@ -72,6 +84,7 @@ def load_explore_json_into_cache(
with app.app_context(): # type: ignore with app.app_context(): # type: ignore
cache_key_prefix = "ejr-" # ejr: explore_json request cache_key_prefix = "ejr-" # ejr: explore_json request
try: try:
ensure_user_is_set(job_metadata.get("user_id"))
datasource_id, datasource_type = get_datasource_info(None, None, form_data) datasource_id, datasource_type = get_datasource_info(None, None, form_data)
viz_obj = get_viz( viz_obj = get_viz(

View File

@@ -26,7 +26,8 @@ from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import ChartDataQueryFailedError from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException from superset.exceptions import SupersetException
from superset.extensions import async_query_manager from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import ( from superset.tasks.async_queries import (
load_chart_data_into_cache, load_chart_data_into_cache,
load_explore_json_into_cache, load_explore_json_into_cache,
@@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase):
def test_load_chart_data_into_cache(self, mock_update_job): def test_load_chart_data_into_cache(self, mock_update_job):
async_query_manager.init_app(app) async_query_manager.init_app(app)
query_context = get_query_context("birth_names") query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = { job_metadata = {
"channel_id": str(uuid4()), "channel_id": str(uuid4()),
"job_id": str(uuid4()), "job_id": str(uuid4()),
"user_id": 1, "user_id": user.id,
"status": "pending", "status": "pending",
"errors": [], "errors": [],
} }
load_chart_data_into_cache(job_metadata, query_context) with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@mock.patch.object( @mock.patch.object(
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo") ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
@@ -67,25 +75,31 @@ class TestAsyncQueries(SupersetTestCase):
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
async_query_manager.init_app(app) async_query_manager.init_app(app)
query_context = get_query_context("birth_names") query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = { job_metadata = {
"channel_id": str(uuid4()), "channel_id": str(uuid4()),
"job_id": str(uuid4()), "job_id": str(uuid4()),
"user_id": 1, "user_id": user.id,
"status": "pending", "status": "pending",
"errors": [], "errors": [],
} }
with pytest.raises(ChartDataQueryFailedError): with pytest.raises(ChartDataQueryFailedError):
load_chart_data_into_cache(job_metadata, query_context) with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)
mock_run_command.assert_called_with(cache=True) mock_run_command.assert_called_once_with(cache=True)
errors = [{"message": "Error: foo"}] errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors) mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job") @mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job): def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app) async_query_manager.init_app(app)
table = get_table_by_name("birth_names") table = get_table_by_name("birth_names")
user = security_manager.find_user("gamma")
form_data = { form_data = {
"datasource": f"{table.id}__table", "datasource": f"{table.id}__table",
"viz_type": "dist_bar", "viz_type": "dist_bar",
@@ -100,29 +114,40 @@ class TestAsyncQueries(SupersetTestCase):
job_metadata = { job_metadata = {
"channel_id": str(uuid4()), "channel_id": str(uuid4()),
"job_id": str(uuid4()), "job_id": str(uuid4()),
"user_id": 1, "user_id": user.id,
"status": "pending", "status": "pending",
"errors": [], "errors": [],
} }
load_explore_json_into_cache(job_metadata, form_data) with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) ensure_user_is_set.assert_called_once_with(user.id)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@mock.patch.object(async_query_manager, "update_job") @mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job): def test_load_explore_json_into_cache_error(self, mock_update_job):
async_query_manager.init_app(app) async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {} form_data = {}
job_metadata = { job_metadata = {
"channel_id": str(uuid4()), "channel_id": str(uuid4()),
"job_id": str(uuid4()), "job_id": str(uuid4()),
"user_id": 1, "user_id": user.id,
"status": "pending", "status": "pending",
"errors": [], "errors": [],
} }
with pytest.raises(SupersetException): with pytest.raises(SupersetException):
load_explore_json_into_cache(job_metadata, form_data) with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)
errors = ["The dataset associated with this chart no longer exists"] errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors) mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)