mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
[format] Using Black (#7769)
This commit is contained in:
@@ -18,21 +18,13 @@
|
||||
from contextlib import closing
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_validators.base import (
|
||||
BaseSQLValidator,
|
||||
SQLValidationAnnotation,
|
||||
)
|
||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||
from superset.utils.core import sources
|
||||
|
||||
MAX_ERROR_ROWS = 10
|
||||
@@ -47,15 +39,11 @@ class PrestoSQLValidationError(Exception):
|
||||
class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
"""Validate SQL queries using Presto's built-in EXPLAIN subtype"""
|
||||
|
||||
name = 'PrestoDBSQLValidator'
|
||||
name = "PrestoDBSQLValidator"
|
||||
|
||||
@classmethod
|
||||
def validate_statement(
|
||||
cls,
|
||||
statement,
|
||||
database,
|
||||
cursor,
|
||||
user_name,
|
||||
cls, statement, database, cursor, user_name
|
||||
) -> Optional[SQLValidationAnnotation]:
|
||||
# pylint: disable=too-many-locals
|
||||
db_engine_spec = database.db_engine_spec
|
||||
@@ -64,28 +52,29 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
# pylint: disable=invalid-name
|
||||
SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
|
||||
SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
|
||||
if SQL_QUERY_MUTATOR:
|
||||
sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
|
||||
|
||||
# Transform the final statement to an explain call before sending it on
|
||||
# to presto to validate
|
||||
sql = f'EXPLAIN (TYPE VALIDATE) {sql}'
|
||||
sql = f"EXPLAIN (TYPE VALIDATE) {sql}"
|
||||
|
||||
# Invoke the query against presto. NB this deliberately doesn't use the
|
||||
# engine spec's handle_cursor implementation since we don't record
|
||||
# these EXPLAIN queries done in validation as proper Query objects
|
||||
# in the superset ORM.
|
||||
from pyhive.exc import DatabaseError
|
||||
|
||||
try:
|
||||
db_engine_spec.execute(cursor, sql)
|
||||
polled = cursor.poll()
|
||||
while polled:
|
||||
logging.info('polling presto for validation progress')
|
||||
stats = polled.get('stats', {})
|
||||
logging.info("polling presto for validation progress")
|
||||
stats = polled.get("stats", {})
|
||||
if stats:
|
||||
state = stats.get('state')
|
||||
if state == 'FINISHED':
|
||||
state = stats.get("state")
|
||||
if state == "FINISHED":
|
||||
break
|
||||
time.sleep(0.2)
|
||||
polled = cursor.poll()
|
||||
@@ -107,30 +96,29 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
# we update at some point in the future.
|
||||
if not db_error.args or not isinstance(db_error.args[0], dict):
|
||||
raise PrestoSQLValidationError(
|
||||
'The pyhive presto client returned an unhandled '
|
||||
'database error.',
|
||||
"The pyhive presto client returned an unhandled " "database error."
|
||||
) from db_error
|
||||
error_args: Dict[str, Any] = db_error.args[0]
|
||||
|
||||
# Confirm the two fields we need to be able to present an annotation
|
||||
# are present in the error response -- a message, and a location.
|
||||
if 'message' not in error_args:
|
||||
if "message" not in error_args:
|
||||
raise PrestoSQLValidationError(
|
||||
'The pyhive presto client did not report an error message',
|
||||
"The pyhive presto client did not report an error message"
|
||||
) from db_error
|
||||
if 'errorLocation' not in error_args:
|
||||
if "errorLocation" not in error_args:
|
||||
raise PrestoSQLValidationError(
|
||||
'The pyhive presto client did not report an error location',
|
||||
"The pyhive presto client did not report an error location"
|
||||
) from db_error
|
||||
|
||||
# Pylint is confused about the type of error_args, despite the hints
|
||||
# and checks above.
|
||||
# pylint: disable=invalid-sequence-index
|
||||
message = error_args['message']
|
||||
err_loc = error_args['errorLocation']
|
||||
line_number = err_loc.get('lineNumber', None)
|
||||
start_column = err_loc.get('columnNumber', None)
|
||||
end_column = err_loc.get('columnNumber', None)
|
||||
message = error_args["message"]
|
||||
err_loc = error_args["errorLocation"]
|
||||
line_number = err_loc.get("lineNumber", None)
|
||||
start_column = err_loc.get("columnNumber", None)
|
||||
end_column = err_loc.get("columnNumber", None)
|
||||
|
||||
return SQLValidationAnnotation(
|
||||
message=message,
|
||||
@@ -139,15 +127,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
end_column=end_column,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(f'Unexpected error running validation query: {e}')
|
||||
logging.exception(f"Unexpected error running validation query: {e}")
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
cls,
|
||||
sql: str,
|
||||
schema: str,
|
||||
database: Any,
|
||||
cls, sql: str, schema: str, database: Any
|
||||
) -> List[SQLValidationAnnotation]:
|
||||
"""
|
||||
Presto supports query-validation queries by running them with a
|
||||
@@ -160,12 +145,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
parsed_query = ParsedQuery(sql)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
logging.info(f'Validating {len(statements)} statement(s)')
|
||||
logging.info(f"Validating {len(statements)} statement(s)")
|
||||
engine = database.get_sqla_engine(
|
||||
schema=schema,
|
||||
nullpool=True,
|
||||
user_name=user_name,
|
||||
source=sources.get('sql_lab', None),
|
||||
source=sources.get("sql_lab", None),
|
||||
)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
@@ -174,13 +159,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
with closing(conn.cursor()) as cursor:
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(
|
||||
statement,
|
||||
database,
|
||||
cursor,
|
||||
user_name,
|
||||
statement, database, cursor, user_name
|
||||
)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logging.debug(f'Validation found {len(annotations)} error(s)')
|
||||
logging.debug(f"Validation found {len(annotations)} error(s)")
|
||||
|
||||
return annotations
|
||||
|
||||
Reference in New Issue
Block a user