mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
@@ -18,11 +18,16 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, cast, Dict, Optional
|
from typing import Any, cast, Dict, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app, g
|
||||||
|
|
||||||
from superset import app
|
from superset import app
|
||||||
from superset.exceptions import SupersetVizException
|
from superset.exceptions import SupersetVizException
|
||||||
from superset.extensions import async_query_manager, cache_manager, celery_app
|
from superset.extensions import (
|
||||||
|
async_query_manager,
|
||||||
|
cache_manager,
|
||||||
|
celery_app,
|
||||||
|
security_manager,
|
||||||
|
)
|
||||||
from superset.utils.cache import generate_cache_key, set_and_log_cache
|
from superset.utils.cache import generate_cache_key, set_and_log_cache
|
||||||
from superset.views.utils import get_datasource_info, get_viz
|
from superset.views.utils import get_datasource_info, get_viz
|
||||||
|
|
||||||
@@ -32,6 +37,12 @@ query_timeout = current_app.config[
|
|||||||
] # TODO: new config key
|
] # TODO: new config key
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_user_is_set(user_id: Optional[int]) -> None:
|
||||||
|
user_is_set = hasattr(g, "user") and g.user is not None
|
||||||
|
if not user_is_set and user_id is not None:
|
||||||
|
g.user = security_manager.get_user_by_id(user_id)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
|
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
|
||||||
def load_chart_data_into_cache(
|
def load_chart_data_into_cache(
|
||||||
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
|
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
|
||||||
@@ -42,6 +53,7 @@ def load_chart_data_into_cache(
|
|||||||
|
|
||||||
with app.app_context(): # type: ignore
|
with app.app_context(): # type: ignore
|
||||||
try:
|
try:
|
||||||
|
ensure_user_is_set(job_metadata.get("user_id"))
|
||||||
command = ChartDataCommand()
|
command = ChartDataCommand()
|
||||||
command.set_query_context(form_data)
|
command.set_query_context(form_data)
|
||||||
result = command.run(cache=True)
|
result = command.run(cache=True)
|
||||||
@@ -72,6 +84,7 @@ def load_explore_json_into_cache(
|
|||||||
with app.app_context(): # type: ignore
|
with app.app_context(): # type: ignore
|
||||||
cache_key_prefix = "ejr-" # ejr: explore_json request
|
cache_key_prefix = "ejr-" # ejr: explore_json request
|
||||||
try:
|
try:
|
||||||
|
ensure_user_is_set(job_metadata.get("user_id"))
|
||||||
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
|
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
|
||||||
|
|
||||||
viz_obj = get_viz(
|
viz_obj = get_viz(
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ from superset.charts.commands.data import ChartDataCommand
|
|||||||
from superset.charts.commands.exceptions import ChartDataQueryFailedError
|
from superset.charts.commands.exceptions import ChartDataQueryFailedError
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
from superset.extensions import async_query_manager
|
from superset.extensions import async_query_manager, security_manager
|
||||||
|
from superset.tasks import async_queries
|
||||||
from superset.tasks.async_queries import (
|
from superset.tasks.async_queries import (
|
||||||
load_chart_data_into_cache,
|
load_chart_data_into_cache,
|
||||||
load_explore_json_into_cache,
|
load_explore_json_into_cache,
|
||||||
@@ -48,17 +49,24 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
def test_load_chart_data_into_cache(self, mock_update_job):
|
def test_load_chart_data_into_cache(self, mock_update_job):
|
||||||
async_query_manager.init_app(app)
|
async_query_manager.init_app(app)
|
||||||
query_context = get_query_context("birth_names")
|
query_context = get_query_context("birth_names")
|
||||||
|
user = security_manager.find_user("gamma")
|
||||||
job_metadata = {
|
job_metadata = {
|
||||||
"channel_id": str(uuid4()),
|
"channel_id": str(uuid4()),
|
||||||
"job_id": str(uuid4()),
|
"job_id": str(uuid4()),
|
||||||
"user_id": 1,
|
"user_id": user.id,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
load_chart_data_into_cache(job_metadata, query_context)
|
with mock.patch.object(
|
||||||
|
async_queries, "ensure_user_is_set"
|
||||||
|
) as ensure_user_is_set:
|
||||||
|
load_chart_data_into_cache(job_metadata, query_context)
|
||||||
|
|
||||||
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
|
ensure_user_is_set.assert_called_once_with(user.id)
|
||||||
|
mock_update_job.assert_called_once_with(
|
||||||
|
job_metadata, "done", result_url=mock.ANY
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch.object(
|
@mock.patch.object(
|
||||||
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
|
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
|
||||||
@@ -67,25 +75,31 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
|
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
|
||||||
async_query_manager.init_app(app)
|
async_query_manager.init_app(app)
|
||||||
query_context = get_query_context("birth_names")
|
query_context = get_query_context("birth_names")
|
||||||
|
user = security_manager.find_user("gamma")
|
||||||
job_metadata = {
|
job_metadata = {
|
||||||
"channel_id": str(uuid4()),
|
"channel_id": str(uuid4()),
|
||||||
"job_id": str(uuid4()),
|
"job_id": str(uuid4()),
|
||||||
"user_id": 1,
|
"user_id": user.id,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
with pytest.raises(ChartDataQueryFailedError):
|
with pytest.raises(ChartDataQueryFailedError):
|
||||||
load_chart_data_into_cache(job_metadata, query_context)
|
with mock.patch.object(
|
||||||
|
async_queries, "ensure_user_is_set"
|
||||||
|
) as ensure_user_is_set:
|
||||||
|
load_chart_data_into_cache(job_metadata, query_context)
|
||||||
|
ensure_user_is_set.assert_called_once_with(user.id)
|
||||||
|
|
||||||
mock_run_command.assert_called_with(cache=True)
|
mock_run_command.assert_called_once_with(cache=True)
|
||||||
errors = [{"message": "Error: foo"}]
|
errors = [{"message": "Error: foo"}]
|
||||||
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
|
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||||
@mock.patch.object(async_query_manager, "update_job")
|
@mock.patch.object(async_query_manager, "update_job")
|
||||||
def test_load_explore_json_into_cache(self, mock_update_job):
|
def test_load_explore_json_into_cache(self, mock_update_job):
|
||||||
async_query_manager.init_app(app)
|
async_query_manager.init_app(app)
|
||||||
table = get_table_by_name("birth_names")
|
table = get_table_by_name("birth_names")
|
||||||
|
user = security_manager.find_user("gamma")
|
||||||
form_data = {
|
form_data = {
|
||||||
"datasource": f"{table.id}__table",
|
"datasource": f"{table.id}__table",
|
||||||
"viz_type": "dist_bar",
|
"viz_type": "dist_bar",
|
||||||
@@ -100,29 +114,40 @@ class TestAsyncQueries(SupersetTestCase):
|
|||||||
job_metadata = {
|
job_metadata = {
|
||||||
"channel_id": str(uuid4()),
|
"channel_id": str(uuid4()),
|
||||||
"job_id": str(uuid4()),
|
"job_id": str(uuid4()),
|
||||||
"user_id": 1,
|
"user_id": user.id,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
load_explore_json_into_cache(job_metadata, form_data)
|
with mock.patch.object(
|
||||||
|
async_queries, "ensure_user_is_set"
|
||||||
|
) as ensure_user_is_set:
|
||||||
|
load_explore_json_into_cache(job_metadata, form_data)
|
||||||
|
|
||||||
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
|
ensure_user_is_set.assert_called_once_with(user.id)
|
||||||
|
mock_update_job.assert_called_once_with(
|
||||||
|
job_metadata, "done", result_url=mock.ANY
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch.object(async_query_manager, "update_job")
|
@mock.patch.object(async_query_manager, "update_job")
|
||||||
def test_load_explore_json_into_cache_error(self, mock_update_job):
|
def test_load_explore_json_into_cache_error(self, mock_update_job):
|
||||||
async_query_manager.init_app(app)
|
async_query_manager.init_app(app)
|
||||||
|
user = security_manager.find_user("gamma")
|
||||||
form_data = {}
|
form_data = {}
|
||||||
job_metadata = {
|
job_metadata = {
|
||||||
"channel_id": str(uuid4()),
|
"channel_id": str(uuid4()),
|
||||||
"job_id": str(uuid4()),
|
"job_id": str(uuid4()),
|
||||||
"user_id": 1,
|
"user_id": user.id,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(SupersetException):
|
with pytest.raises(SupersetException):
|
||||||
load_explore_json_into_cache(job_metadata, form_data)
|
with mock.patch.object(
|
||||||
|
async_queries, "ensure_user_is_set"
|
||||||
|
) as ensure_user_is_set:
|
||||||
|
load_explore_json_into_cache(job_metadata, form_data)
|
||||||
|
ensure_user_is_set.assert_called_once_with(user.id)
|
||||||
|
|
||||||
errors = ["The dataset associated with this chart no longer exists"]
|
errors = ["The dataset associated with this chart no longer exists"]
|
||||||
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
|
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
|
||||||
|
|||||||
Reference in New Issue
Block a user