refactor: sql_json view endpoint (#16441)

* refactor sql_json view endpoint

* fix pylint

* renaming

* renaming

Co-authored-by: Amit Miran <47772523+amitmiran137@users.noreply.github.com>
This commit is contained in:
ofekisr
2021-08-25 15:51:48 +03:00
committed by GitHub
parent 6a2cec51c5
commit 93c60e4021
2 changed files with 149 additions and 56 deletions

View File

@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any, cast, Dict, Optional
from flask import g
from superset import app, is_feature_enabled
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
# Local shorthand for the QueryStatus object exposed by superset.utils.core.
QueryStatus = utils.QueryStatus
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Alias for a string-keyed dict; presumably the shape of a SQL Lab results
# payload (it is not used by the code in this module -- confirm with callers).
SqlResults = Dict[str, Any]
@dataclass  # pylint: disable=R0902
class SqlJsonExecutionContext:
    """Parameters of one ``sql_json`` request, collected into a single object.

    The instance is built from the raw request payload (``query_params``).
    Most attributes are copied straight from a payload key; the remaining ones
    (``template_params``, ``limit``, ``status``, ``user_id``,
    ``client_id_or_short_id`` and ``expand_data``) are derived by the helper
    methods below.

    The class defines its own ``__init__``, so ``@dataclass`` does not generate
    one; the decorator still contributes the generated ``__repr__``/``__eq__``.
    """

    database_id: int  # payload key "database_id"
    schema: str  # payload key "schema"
    sql: str  # payload key "sql"
    template_params: Dict[str, Any]  # decoded from the JSON in "templateParams"
    async_flag: bool  # payload key "runAsync"
    limit: int  # "queryLimit", or the SQL_MAX_ROW config value
    status: str  # "status", or derived from async_flag when absent
    select_as_cta: bool  # payload key "select_as_cta"
    ctas_method: CtasMethod  # "ctas_method", CtasMethod.TABLE when absent
    tmp_table_name: str  # payload key "tmp_table_name"
    client_id: str  # payload key "client_id"
    client_id_or_short_id: str  # client_id, or a generated 10-character id
    sql_editor_id: str  # payload key "sql_editor_id"
    tab_name: str  # payload key "tab"
    user_id: Optional[int]  # id of the current Flask user, if there is one
    expand_data: bool  # "expand_data", gated by the PRESTO_EXPAND_DATA flag

    def __init__(self, query_params: Dict[str, Any]):
        """Populate the context from the request payload ``query_params``."""
        self._init_from_query_params(query_params)
        self.user_id = self._get_user_id()
        # Fall back to a freshly generated short id when the client sent none.
        self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])

    def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
        """Copy/derive every payload-backed attribute from ``query_params``.

        ``cast`` is only a type-checker hint: missing keys are stored as
        ``None`` and values are not converted.
        """
        self.database_id = cast(int, query_params.get("database_id"))
        self.schema = cast(str, query_params.get("schema"))
        self.sql = cast(str, query_params.get("sql"))
        self.template_params = self._get_template_params(query_params)
        self.async_flag = cast(bool, query_params.get("runAsync"))
        self.limit = self._get_limit_param(query_params)
        # Must run after async_flag is set: the default status depends on it.
        self.status = self._get_status(query_params, self.async_flag)
        self.select_as_cta = cast(bool, query_params.get("select_as_cta"))
        self.ctas_method = cast(
            CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
        )
        self.tmp_table_name = cast(str, query_params.get("tmp_table_name"))
        self.client_id = cast(str, query_params.get("client_id"))
        self.sql_editor_id = cast(str, query_params.get("sql_editor_id"))
        self.tab_name = cast(str, query_params.get("tab"))
        # The payload value is only considered when the feature flag is on.
        self.expand_data: bool = cast(
            bool,
            is_feature_enabled("PRESTO_EXPAND_DATA")
            and query_params.get("expand_data"),
        )

    @staticmethod
    def _get_status(query_params: Dict[str, Any], async_flag: bool) -> str:
        """Return the initial query status.

        An explicit, non-empty ``"status"`` in the payload is kept as is.
        Otherwise the status is ``QueryStatus.PENDING`` for asynchronous runs
        and ``QueryStatus.RUNNING`` for synchronous ones, so a payload without
        a ``"status"`` key no longer yields ``None``.
        """
        explicit_status = query_params.get("status")
        if explicit_status:
            return cast(str, explicit_status)
        return QueryStatus.PENDING if async_flag else QueryStatus.RUNNING

    @staticmethod
    def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the JSON string under ``"templateParams"`` into a dict.

        A missing or empty value gives ``{}``. Anything that cannot be decoded,
        or that decodes to something other than a JSON object (``null``, a
        list, a number, ...), is logged and replaced by ``{}`` because the
        result is later expanded as keyword arguments.
        """
        raw_params = query_params.get("templateParams")
        try:
            template_params = json.loads(raw_params or "{}")
        except (json.JSONDecodeError, TypeError):
            # TypeError: the payload value was not a str/bytes at all.
            template_params = None
        if not isinstance(template_params, dict):
            logger.warning(
                "Invalid template parameter %s specified. Defaulting to empty dict",
                str(raw_params),
            )
            return {}
        return template_params

    @staticmethod
    def _get_limit_param(query_params: Dict[str, Any]) -> int:
        """Return the row limit: ``"queryLimit"`` or the ``SQL_MAX_ROW`` config.

        A falsy ``"queryLimit"`` (missing, ``None`` or ``0``) selects the
        configured value. A negative limit is logged and replaced by ``0``.
        """
        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
        if limit < 0:
            logger.warning(
                "Invalid limit of %i specified. Defaulting to max limit.", limit
            )
            # NOTE(review): 0 is presumably read downstream as "use the max
            # limit", as the log message says -- confirm against the consumer.
            limit = 0
        return limit

    def _get_user_id(self) -> Optional[int]:  # pylint: disable=R0201
        """Return the id of ``g.user``, or ``None`` when it is unavailable.

        ``RuntimeError`` is swallowed on purpose; presumably it covers access
        to ``g`` outside of a Flask application context -- confirm.
        """
        try:
            return g.user.get_id() if g.user else None
        except RuntimeError:
            return None

    def is_run_asynchronous(self) -> bool:
        """Return the ``runAsync`` flag taken from the request payload."""
        return self.async_flag

View File

@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=comparison-with-callable, line-too-long, too-many-branches
# pylint: disable=comparison-with-callable, line-too-long
import dataclasses
import logging
import re
@@ -100,7 +100,7 @@ from superset.models.sql_lab import LimitingFactor, Query, TabState
from superset.models.user_attributes import UserAttribute
from superset.queries.dao import QueryDAO
from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.typing import FlaskResponse
@@ -110,6 +110,7 @@ from superset.utils.cache import etag_cache
from superset.utils.core import ReservedUrlParameters
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
from superset.views.base import (
api,
BaseSupersetView,
@@ -2577,42 +2578,16 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
return self.sql_json_exec(request.json, log_params)
execution_context = SqlJsonExecutionContext(request.json)
return self.sql_json_exec(execution_context, request.json, log_params)
def sql_json_exec( # pylint: disable=too-many-statements,too-many-locals
self, query_params: Dict[str, Any], log_params: Optional[Dict[str, Any]] = None
self,
execution_context: SqlJsonExecutionContext,
query_params: Dict[str, Any],
log_params: Optional[Dict[str, Any]] = None,
) -> FlaskResponse:
"""Runs arbitrary sql and returns data as json"""
# Collect Values
database_id: int = cast(int, query_params.get("database_id"))
schema: str = cast(str, query_params.get("schema"))
sql: str = cast(str, query_params.get("sql"))
try:
template_params = json.loads(query_params.get("templateParams") or "{}")
except json.JSONDecodeError:
logger.warning(
"Invalid template parameter %s" " specified. Defaulting to empty dict",
str(query_params.get("templateParams")),
)
template_params = {}
limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
async_flag: bool = cast(bool, query_params.get("runAsync"))
if limit < 0:
logger.warning(
"Invalid limit of %i specified. Defaulting to max limit.", limit
)
limit = 0
select_as_cta: bool = cast(bool, query_params.get("select_as_cta"))
ctas_method: CtasMethod = cast(
CtasMethod, query_params.get("ctas_method", CtasMethod.TABLE)
)
tmp_table_name: str = cast(str, query_params.get("tmp_table_name"))
client_id: str = cast(str, query_params.get("client_id"))
client_id_or_short_id: str = cast(str, client_id or utils.shortid()[:10])
sql_editor_id: str = cast(str, query_params.get("sql_editor_id"))
tab_name: str = cast(str, query_params.get("tab"))
status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
user_id: int = g.user.get_id() if g.user else None
session = db.session()
@@ -2620,7 +2595,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
query = (
session.query(Query)
.filter_by(
client_id=client_id, user_id=user_id, sql_editor_id=sql_editor_id
client_id=execution_context.client_id,
user_id=execution_context.user_id,
sql_editor_id=execution_context.sql_editor_id,
)
.one_or_none()
)
@@ -2635,7 +2612,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
return json_success(payload)
mydb = session.query(Database).get(database_id)
mydb = session.query(Database).get(execution_context.database_id)
if not mydb:
raise SupersetGenericErrorException(
__(
@@ -2648,27 +2625,29 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# TODO(bkyryliuk): consider parsing, splitting tmp_schema_name from
# tmp_table_name if user enters
# <schema_name>.<table_name>
tmp_schema_name: Optional[str] = schema
if select_as_cta and mydb.force_ctas_schema:
tmp_schema_name: Optional[str] = execution_context.schema
if execution_context.select_as_cta and mydb.force_ctas_schema:
tmp_schema_name = mydb.force_ctas_schema
elif select_as_cta:
tmp_schema_name = get_cta_schema_name(mydb, g.user, schema, sql)
elif execution_context.select_as_cta:
tmp_schema_name = get_cta_schema_name(
mydb, g.user, execution_context.schema, execution_context.sql
)
# Save current query
query = Query(
database_id=database_id,
sql=sql,
schema=schema,
select_as_cta=select_as_cta,
ctas_method=ctas_method,
database_id=execution_context.database_id,
sql=execution_context.sql,
schema=execution_context.schema,
select_as_cta=execution_context.select_as_cta,
ctas_method=execution_context.ctas_method,
start_time=now_as_float(),
tab_name=tab_name,
status=status,
sql_editor_id=sql_editor_id,
tmp_table_name=tmp_table_name,
tab_name=execution_context.tab_name,
status=execution_context.status,
sql_editor_id=execution_context.sql_editor_id,
tmp_table_name=execution_context.tmp_table_name,
tmp_schema_name=tmp_schema_name,
user_id=user_id,
client_id=client_id_or_short_id,
user_id=execution_context.user_id,
client_id=execution_context.client_id_or_short_id,
)
try:
session.add(query)
@@ -2703,7 +2682,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **template_params
query.sql, **execution_context.template_params
)
except TemplateError as ex:
query.status = QueryStatus.FAILED
@@ -2734,15 +2713,18 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": template_params,
"template_parameters": execution_context.template_params,
},
)
# Limit is not applied to the CTA queries if SQLLAB_CTAS_NO_LIMIT flag is set
# to True.
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and select_as_cta):
if not (config.get("SQLLAB_CTAS_NO_LIMIT") and execution_context.select_as_cta):
# set LIMIT after template processing
limits = [mydb.db_engine_spec.get_limit_from_sql(rendered_query), limit]
limits = [
mydb.db_engine_spec.get_limit_from_sql(rendered_query),
execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]:
query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]:
@@ -2760,7 +2742,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
# Async request.
if async_flag:
if execution_context.is_run_asynchronous():
return self._sql_json_async(
session, rendered_query, query, expand_data, log_params
)