chore(command): Organize Commands according to SIP-92 (#25850)

This commit is contained in:
John Bodley
2023-11-22 11:55:54 -08:00
committed by GitHub
parent 984c278c4c
commit 07bcfa9b5f
265 changed files with 786 additions and 808 deletions

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any
from flask_babel import gettext as __
from superset import app, db
from superset.commands.base import BaseCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetTimeoutException
from superset.jinja_context import get_template_processor
from superset.models.core import Database
from superset.sqllab.schemas import EstimateQueryCostSchema
from superset.utils import core as utils
config = app.config
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
stats_logger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
class QueryEstimationCommand(BaseCommand):
_database_id: int
_sql: str
_template_params: dict[str, Any]
_schema: str
_database: Database
def __init__(self, params: EstimateQueryCostSchema) -> None:
self._database_id = params.get("database_id")
self._sql = params.get("sql", "")
self._template_params = params.get("template_params", {})
self._schema = params.get("schema", "")
def validate(self) -> None:
self._database = db.session.query(Database).get(self._database_id)
if not self._database:
raise SupersetErrorException(
SupersetError(
message=__("The database could not be found"),
error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
)
def run(
self,
) -> list[dict[str, Any]]:
self.validate()
sql = self._sql
if self._template_params:
template_processor = get_template_processor(self._database)
sql = template_processor.process_template(sql, **self._template_params)
timeout = SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT
timeout_msg = f"The estimation exceeded the {timeout} seconds timeout."
try:
with utils.timeout(seconds=timeout, error_message=timeout_msg):
cost = self._database.db_engine_spec.estimate_query_cost(
self._database, self._schema, sql, utils.QuerySource.SQL_LAB
)
except SupersetTimeoutException as ex:
logger.exception(ex)
raise SupersetErrorException(
SupersetError(
message=__(
"The query estimation was killed after %(sqllab_timeout)s "
"seconds. It might be too complex, or the database might be "
"under heavy load.",
sqllab_timeout=SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT,
),
error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
level=ErrorLevel.ERROR,
),
status=500,
) from ex
spec = self._database.db_engine_spec
query_cost_formatters: dict[str, Any] = app.config[
"QUERY_COST_FORMATTERS_BY_ENGINE"
]
query_cost_formatter = query_cost_formatters.get(
spec.engine, spec.query_cost_formatter
)
cost = query_cost_formatter(cost)
return cost

View File

@@ -0,0 +1,234 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods, too-many-arguments
from __future__ import annotations
import copy
import logging
from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __
from superset.commands.base import BaseCommand
from superset.common.db_query_status import QueryStatus
from superset.daos.exceptions import DAOCreateFailedError
from superset.errors import SupersetErrorType
from superset.exceptions import (
SupersetErrorException,
SupersetErrorsException,
SupersetGenericErrorException,
)
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import (
QueryIsForbiddenToAccessException,
SqlLabException,
)
from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
from superset.sqllab.limiting_factor import LimitingFactor
if TYPE_CHECKING:
from superset.daos.database import DatabaseDAO
from superset.daos.query import QueryDAO
from superset.sqllab.sql_json_executer import SqlJsonExecutor
from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
logger = logging.getLogger(__name__)
CommandResult = dict[str, Any]
class ExecuteSqlCommand(BaseCommand):
_execution_context: SqlJsonExecutionContext
_query_dao: QueryDAO
_database_dao: DatabaseDAO
_access_validator: CanAccessQueryValidator
_sql_query_render: SqlQueryRender
_sql_json_executor: SqlJsonExecutor
_execution_context_convertor: ExecutionContextConvertor
_sqllab_ctas_no_limit: bool
_log_params: dict[str, Any] | None = None
def __init__(
self,
execution_context: SqlJsonExecutionContext,
query_dao: QueryDAO,
database_dao: DatabaseDAO,
access_validator: CanAccessQueryValidator,
sql_query_render: SqlQueryRender,
sql_json_executor: SqlJsonExecutor,
execution_context_convertor: ExecutionContextConvertor,
sqllab_ctas_no_limit_flag: bool,
log_params: dict[str, Any] | None = None,
) -> None:
self._execution_context = execution_context
self._query_dao = query_dao
self._database_dao = database_dao
self._access_validator = access_validator
self._sql_query_render = sql_query_render
self._sql_json_executor = sql_json_executor
self._execution_context_convertor = execution_context_convertor
self._sqllab_ctas_no_limit = sqllab_ctas_no_limit_flag
self._log_params = log_params
def validate(self) -> None:
pass
def run( # pylint: disable=too-many-statements,useless-suppression
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
try:
query = self._try_get_existing_query()
if self.is_query_handled(query):
self._execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
self._execution_context_convertor.set_payload(
self._execution_context, status
)
# save columns into metadata_json
self._query_dao.save_metadata(
self._execution_context.query, self._execution_context_convertor.payload
)
return {
"status": status,
"payload": self._execution_context_convertor.serialize_payload(),
}
except (SupersetErrorException, SupersetErrorsException) as ex:
# to make sure we raising the original
# SupersetErrorsException || SupersetErrorsException
raise ex
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex
def _try_get_existing_query(self) -> Query | None:
return self._query_dao.find_one_or_none(
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
@classmethod
def is_query_handled(cls, query: Query | None) -> bool:
return query is not None and query.status in [
QueryStatus.RUNNING,
QueryStatus.PENDING,
QueryStatus.TIMED_OUT,
]
def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
self._execution_context.set_database(self._get_the_query_db())
query = self._execution_context.create_query()
self._save_new_query(query)
try:
logger.info("Triggering query_id: %i", query.id)
self._execution_context.set_query(query)
rendered_query = self._sql_query_render.render(self._execution_context)
validate_rendered_query = copy.copy(query)
validate_rendered_query.sql = rendered_query
self._validate_access(validate_rendered_query)
self._set_query_limit_if_required(rendered_query)
self._query_dao.update(
query, {"limit": self._execution_context.query.limit}
)
return self._sql_json_executor.execute(
self._execution_context, rendered_query, self._log_params
)
except Exception as ex:
self._query_dao.update(query, {"status": QueryStatus.FAILED})
raise ex
def _get_the_query_db(self) -> Database:
mydb: Any = self._database_dao.find_by_id(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb
@classmethod
def _validate_query_db(cls, database: Database | None) -> None:
if not database:
raise SupersetGenericErrorException(
__(
"The database referenced in this query was not found. Please "
"contact an administrator for further assistance or try again."
)
)
def _save_new_query(self, query: Query) -> None:
try:
self._query_dao.create(query)
except DAOCreateFailedError as ex:
raise SqlLabException(
self._execution_context,
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
"The query record was not created as expected",
ex,
"Please contact an administrator for further assistance or try again.",
) from ex
def _validate_access(self, query: Query) -> None:
try:
self._access_validator.validate(query)
except Exception as ex:
raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex
def _set_query_limit_if_required(
self,
rendered_query: str,
) -> None:
if self._is_required_to_set_limit():
self._set_query_limit(rendered_query)
def _is_required_to_set_limit(self) -> bool:
return not (
self._sqllab_ctas_no_limit and self._execution_context.select_as_cta
)
def _set_query_limit(self, rendered_query: str) -> None:
db_engine_spec = self._execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
self._execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
self._execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
self._execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
self._execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
self._execution_context.query.limit = min(
lim for lim in limits if lim is not None
)
class CanAccessQueryValidator:
def validate(self, query: Query) -> None:
raise NotImplementedError()
class SqlQueryRender:
def render(self, execution_context: SqlJsonExecutionContext) -> str:
raise NotImplementedError()

View File

@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, cast, TypedDict
import pandas as pd
from flask_babel import gettext as __
from superset import app, db, results_backend, results_backend_use_msgpack
from superset.commands.base import BaseCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils import core as utils, csv
from superset.views.utils import _deserialize_results_payload
config = app.config
logger = logging.getLogger(__name__)
class SqlExportResult(TypedDict):
query: Query
count: int
data: list[Any]
class SqlResultExportCommand(BaseCommand):
_client_id: str
_query: Query
def __init__(
self,
client_id: str,
) -> None:
self._client_id = client_id
def validate(self) -> None:
self._query = (
db.session.query(Query).filter_by(client_id=self._client_id).one_or_none()
)
if self._query is None:
raise SupersetErrorException(
SupersetError(
message=__(
"The query associated with these results could not be found. "
"You need to re-run the original query."
),
error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
)
try:
self._query.raise_for_access()
except SupersetSecurityException as ex:
raise SupersetErrorException(
SupersetError(
message=__("Cannot access the query"),
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
level=ErrorLevel.ERROR,
),
status=403,
) from ex
def run(
self,
) -> SqlExportResult:
self.validate()
blob = None
if results_backend and self._query.results_key:
logger.info(
"Fetching CSV from results backend [%s]", self._query.results_key
)
blob = results_backend.get(self._query.results_key)
if blob:
logger.info("Decompressing")
payload = utils.zlib_decompress(
blob, decode=not results_backend_use_msgpack
)
obj = _deserialize_results_payload(
payload, self._query, cast(bool, results_backend_use_msgpack)
)
df = pd.DataFrame(
data=obj["data"],
dtype=object,
columns=[c["name"] for c in obj["columns"]],
)
logger.info("Using pandas to convert to CSV")
else:
logger.info("Running a query to turn into CSV")
if self._query.select_sql:
sql = self._query.select_sql
limit = None
else:
sql = self._query.executed_sql
limit = ParsedQuery(sql).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
LimitingFactor.QUERY_AND_DROPDOWN,
}:
# remove extra row from `increased_limit`
limit -= 1
df = self._query.database.get_df(sql, self._query.schema)[:limit]
csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"])
return {
"query": self._query,
"count": len(df.index),
"data": csv_data,
}

View File

@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, cast
from flask_babel import gettext as __
from superset import app, db, results_backend, results_backend_use_msgpack
from superset.commands.base import BaseCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SerializationError, SupersetErrorException
from superset.models.sql_lab import Query
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
from superset.utils.dates import now_as_float
from superset.views.utils import _deserialize_results_payload
config = app.config
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = config["SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT"]
stats_logger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
class SqlExecutionResultsCommand(BaseCommand):
_key: str
_rows: int | None
_blob: Any
_query: Query
def __init__(
self,
key: str,
rows: int | None = None,
) -> None:
self._key = key
self._rows = rows
def validate(self) -> None:
if not results_backend:
raise SupersetErrorException(
SupersetError(
message=__("Results backend is not configured."),
error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR,
level=ErrorLevel.ERROR,
)
)
read_from_results_backend_start = now_as_float()
self._blob = results_backend.get(self._key)
stats_logger.timing(
"sqllab.query.results_backend_read",
now_as_float() - read_from_results_backend_start,
)
if not self._blob:
raise SupersetErrorException(
SupersetError(
message=__(
"Data could not be retrieved from the results backend. You "
"need to re-run the original query."
),
error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
level=ErrorLevel.ERROR,
),
status=410,
)
self._query = (
db.session.query(Query).filter_by(results_key=self._key).one_or_none()
)
if self._query is None:
raise SupersetErrorException(
SupersetError(
message=__(
"The query associated with these results could not be found. "
"You need to re-run the original query."
),
error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
)
def run(
self,
) -> dict[str, Any]:
"""Runs arbitrary sql and returns data as json"""
self.validate()
payload = utils.zlib_decompress(
self._blob, decode=not results_backend_use_msgpack
)
try:
obj = _deserialize_results_payload(
payload, self._query, cast(bool, results_backend_use_msgpack)
)
except SerializationError as ex:
raise SupersetErrorException(
SupersetError(
message=__(
"Data could not be deserialized from the results backend. The "
"storage format might have changed, rendering the old data "
"stake. You need to re-run the original query."
),
error_type=SupersetErrorType.RESULTS_BACKEND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
) from ex
if self._rows:
obj = apply_display_max_row_configuration_if_require(obj, self._rows)
return obj